tensorrtx搭建Zero-DCE部署

Zero-DCE介绍

Zero-DCE论文作者的项目主页:https://li-chongyi.github.io/Proj_Zero-DCE.html

Zero-DCE借鉴了深度曲线估计的思路:用一个轻量的深度卷积神经网络逐像素估计光照增强曲线(LE-curve)的参数,再用该曲线对微光图像进行增强,从而调整在不同光照条件下采集的光照不均匀或弱光图像。
网络最后一层输出24个通道的曲线参数图(经tanh激活),按每3个通道(对应RGB三个通道)切分为8组 r1…r8。对输入图像逐像素迭代应用8次曲线 x = x + r_i·(x² − x),每次迭代的输出作为下一次迭代的输入。

tensorrtx改写网络结构

pytorch实现:https://github.com/Li-Chongyi/Zero-DCE

class enhance_net_nopool(nn.Module):
    """DCE-Net of Zero-DCE (https://github.com/Li-Chongyi/Zero-DCE).

    Seven 3x3 convolutions (stride 1, padding 1, 32 feature channels) with
    symmetric skip concatenations (x3+x4 -> conv5, x2+x5 -> conv6,
    x1+x6 -> conv7) predict 24 curve-parameter channels squashed into
    [-1, 1] by tanh.  These are split into 8 maps of 3 channels each, and
    every map r is applied to the image as the light-enhancement curve
        x <- x + r * (x**2 - x).
    """

    def __init__(self):
        super(enhance_net_nopool, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        number_f = 32
        # Attribute names must stay e_conv1..e_conv7: they are the state_dict
        # keys that gen_wts.py exports and the TensorRT weightMap looks up.
        self.e_conv1 = nn.Conv2d(3, number_f, 3, 1, 1, bias=True)
        self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv5 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv6 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv7 = nn.Conv2d(number_f * 2, 24, 3, 1, 1, bias=True)
        # Not used by forward(). They hold no parameters (so they do not
        # appear in the state_dict) and are kept only to mirror upstream.
        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        """Enhance a batch of images.

        Args:
            x: tensor of shape (N, 3, H, W); presumably an RGB image scaled
               to [0, 1] -- TODO confirm against the caller.

        Returns:
            (enhance_image_1, enhance_image, r): the image after 4 curve
            iterations, the image after all 8, and the 24-channel
            concatenation of the curve maps.
        """
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        # torch.tanh instead of the deprecated F.tanh (numerically identical).
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))

        # 24 channels -> 8 curve maps r1..r8 of 3 channels each.
        curves = torch.split(x_r, 3, dim=1)
        enhance_image_1 = None
        for i, r_i in enumerate(curves):
            x = x + r_i * (torch.pow(x, 2) - x)
            if i == 3:
                # Intermediate result after the 4th iteration (upstream output).
                enhance_image_1 = x
        enhance_image = x
        r = torch.cat(curves, 1)
        return enhance_image_1, enhance_image, r

Zero-DCE是一个轻量网络,使用到的算子也不复杂,算子都为tensorrt已经支持的操作,我们可以直接调用tensorrt的操作来部署,不需要重新自定义算子。

  auto pow = network->addElementWise(*data, *data, ElementWiseOperation::kPROD);
  auto sub = network->addElementWise(*pow->getOutput(0), *data, ElementWiseOperation::kSUB);
  //e_conv1
  IConvolutionLayer* conv1 = network->addConvolutionNd(*data, 32, DimsHW{ 3, 3 }, weightMap["e_conv1.weight"], weightMap["e_conv1.bias"]);
  assert(conv1);
  conv1->setPaddingNd(DimsHW{ 1, 1 });
  conv1->setStrideNd(DimsHW{ 1,1 });
  IActivationLayer* relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU);
  assert(relu1);
  //e_conv2
  IConvolutionLayer* conv2 = network->addConvolutionNd(*relu1->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv2.weight"], weightMap["e_conv2.bias"]);
  conv2->setPaddingNd(DimsHW{ 1, 1 });
  conv2->setStrideNd(DimsHW{ 1,1 });
  IActivationLayer* relu2 = network->addActivation(*conv2->getOutput(0), ActivationType::kRELU);
  //e_conv3
  IConvolutionLayer* conv3 = network->addConvolutionNd(*relu2->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv3.weight"], weightMap["e_conv3.bias"]);
  conv3->setPaddingNd(DimsHW{ 1, 1 });
  conv3->setStrideNd(DimsHW{ 1,1 });
  IActivationLayer* relu3 = network->addActivation(*conv3->getOutput(0), ActivationType::kRELU);
  //e_conv4
  IConvolutionLayer* conv4 = network->addConvolutionNd(*relu3->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv4.weight"], weightMap["e_conv4.bias"]);
  conv4->setPaddingNd(DimsHW{ 1, 1 });
  conv4->setStrideNd(DimsHW{ 1,1 });
  IActivationLayer* relu4 = network->addActivation(*conv4->getOutput(0), ActivationType::kRELU);
  //concat relu3 and relu4
  ITensor* inputTensors34[] = { relu3->getOutput(0), relu4->getOutput(0) };
  auto cat34 = network->addConcatenation(inputTensors34, 2);
  //e_conv5
  IConvolutionLayer* conv5 = network->addConvolutionNd(*cat34->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv5.weight"], weightMap["e_conv5.bias"]);
  conv5->setPaddingNd(DimsHW{ 1, 1 });
  conv5->setStrideNd(DimsHW{ 1,1 });
  IActivationLayer* relu5 = network->addActivation(*conv5->getOutput(0), ActivationType::kRELU);
  //concat relu2 and relu5
  ITensor* inputTensors25[] = { relu2->getOutput(0), relu5->getOutput(0) };
  auto cat25 = network->addConcatenation(inputTensors25, 2);
  //e_conv6
  IConvolutionLayer* conv6 = network->addConvolutionNd(*cat25->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv6.weight"], weightMap["e_conv6.bias"]);
  conv6->setPaddingNd(DimsHW{ 1, 1 });
  conv6->setStrideNd(DimsHW{ 1,1 });
  IActivationLayer* relu6 = network->addActivation(*conv6->getOutput(0), ActivationType::kRELU);
  //concat relu1 and relu6
  ITensor* inputTensors16[] = { relu1->getOutput(0), relu6->getOutput(0) };
  auto cat16 = network->addConcatenation(inputTensors16, 2);
  //e_conv7
  IConvolutionLayer* conv7 = network->addConvolutionNd(*cat16->getOutput(0), 24, DimsHW{ 3, 3 }, weightMap["e_conv7.weight"], weightMap["e_conv7.bias"]);
  conv7->setPaddingNd(DimsHW{ 1, 1 });
  conv7->setStrideNd(DimsHW{ 1,1 });
  IActivationLayer* relu7 = network->addActivation(*conv7->getOutput(0), ActivationType::kTANH);
  //addSlice
  Dims d = relu7->getOutput(0)->getDimensions();
  ISliceLayer* slice0 = network->addSlice(*relu7->getOutput(0), Dims3{ 0,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  ISliceLayer* slice1 = network->addSlice(*relu7->getOutput(0), Dims3{ 1,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  ISliceLayer* slice2 = network->addSlice(*relu7->getOutput(0), Dims3{ 2,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  ISliceLayer* slice3 = network->addSlice(*relu7->getOutput(0), Dims3{ 3,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  ISliceLayer* slice4 = network->addSlice(*relu7->getOutput(0), Dims3{ 4,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  ISliceLayer* slice5 = network->addSlice(*relu7->getOutput(0), Dims3{ 5,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  ISliceLayer* slice6 = network->addSlice(*relu7->getOutput(0), Dims3{ 6,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  ISliceLayer* slice7 = network->addSlice(*relu7->getOutput(0), Dims3{ 7,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 });
  //split
  auto mul = network->addElementWise(*slice0->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  auto add = network->addElementWise(*data, *mul->getOutput(0), ElementWiseOperation::kSUM);
  pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD);
  sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB);
  mul = network->addElementWise(*slice1->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM);
  pow = pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD);
  sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB);
  mul = network->addElementWise(*slice2->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM);
  pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD);
  sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB);
  mul = network->addElementWise(*slice3->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM);
  pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD);
  sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB);
  mul = network->addElementWise(*slice4->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM);
  pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD);
  sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB);
  mul = network->addElementWise(*slice5->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM);
  pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD);
  sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB);
  mul = network->addElementWise(*slice6->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM);
  pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD);
  sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB);
  mul = network->addElementWise(*slice7->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD);
  add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM);
  add->getOutput(0)->setName(OUTPUT_BLOB_NAME);

Zero-DCE使用的tensorrt算子:

| TensorRT | PyTorch |
|---|---|
| addElementWise(ElementWiseOperation::kSUM) | + |
| addElementWise(ElementWiseOperation::kSUB) | - |
| addElementWise(ElementWiseOperation::kPROD) | *、torch.pow(x, 2)(以 x*x 实现) |
| addConvolutionNd | nn.Conv2d |
| addActivation(ActivationType::kRELU) | ReLU |
| addActivation(ActivationType::kTANH) | tanh |
| addConcatenation | torch.cat |
| addSlice | torch.split |

部署Zero-DCE

tensorrtx部署步骤参考:https://github.com/wang-xinyu/tensorrtx

将Zero-DCE模型转换为wts权重文件

python gen_wts.py -w zero_dce.pt -o zero_dce.wts

gen_wts.py

"""Convert a Zero-DCE PyTorch checkpoint into a tensorrtx .wts file.

Output format: the first line holds the number of tensors; every following
line is "<name> <count>  <hex> <hex> ...", where each <hex> is the
big-endian float32 bit pattern of one value (read back by loadWeights()).
"""
import sys
import argparse
import os
import struct
import torch
# The yolov5 helper `utils.torch_utils.select_device` is intentionally not
# imported: it does not exist in the Zero-DCE repository, and loading on the
# CPU only needs map_location='cpu'.


def parse_args():
    """Parse -w/--weights and -o/--output; return (pt_file, wts_file).

    When -o is omitted the .wts is written next to the input; when -o is a
    directory the input's base name is reused inside it.
    """
    parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
    parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
    parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')
    args = parser.parse_args()
    if not os.path.isfile(args.weights):
        raise SystemExit('Invalid input file')
    if not args.output:
        args.output = os.path.splitext(args.weights)[0] + '.wts'
    elif os.path.isdir(args.output):
        args.output = os.path.join(
            args.output,
            os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
    return args.weights, args.output


def extract_state_dict(checkpoint):
    """Return a name -> tensor dict from a loaded checkpoint.

    Accepts a plain state_dict, an nn.Module, or a dict that stores either of
    those under the key 'model'. A leading 'module.' (left by DataParallel)
    is stripped so names match the e_convN.* keys used by the TensorRT code.
    """
    if isinstance(checkpoint, dict) and 'model' in checkpoint and not torch.is_tensor(checkpoint['model']):
        checkpoint = checkpoint['model']
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.float().state_dict()
    state_dict = {}
    for name, tensor in checkpoint.items():
        if name.startswith('module.'):
            name = name[len('module.'):]
        state_dict[name] = tensor
    return state_dict


def write_wts(state_dict, wts_file):
    """Write state_dict to wts_file in the tensorrtx .wts text format (FP32)."""
    with open(wts_file, 'w') as f:
        f.write('{}\n'.format(len(state_dict)))
        for k, v in state_dict.items():
            vr = v.detach().reshape(-1).cpu().float().numpy()
            f.write('{} {} '.format(k, len(vr)))
            for vv in vr:
                f.write(' ')
                f.write(struct.pack('>f', float(vv)).hex())
            f.write('\n')


def main():
    pt_file, wts_file = parse_args()
    # Plain state_dicts load with torch.load's default settings. A pickled
    # full nn.Module may need weights_only=False on recent PyTorch -- only do
    # that for files you trust, since it unpickles arbitrary objects.
    checkpoint = torch.load(pt_file, map_location='cpu')
    write_wts(extract_state_dict(checkpoint), wts_file)


if __name__ == '__main__':
    main()

读取zero_dce.wts中的权重,存入构建网络时使用的weightMap:

// Load a tensorrtx .wts file into a name -> Weights map.
//
// Format (as written by gen_wts.py): the first token is the number of blobs;
// each blob is "<name> <count>" (decimal) followed by <count> hexadecimal
// 32-bit words, each the bit pattern of one float32 value.
//
// Every Weights::values buffer is allocated with calloc(), so it must be
// released with free() -- presumably by the caller once the engine has been
// built (TODO confirm in the caller).
std::map<std::string, Weights> loadWeights(const std::string file) {
    std::cout << "Loading weights: " << file << std::endl;
    std::map<std::string, Weights> weightMap;
    std::ifstream input(file);
    assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");
    // Read number of weight blobs
    int32_t count = 0;
    input >> count;
    assert(count > 0 && "Invalid weight map file.");
    while (count--)
    {
        std::string name;
        uint32_t size = 0;
        // std::dec undoes the std::hex left on the stream by the previous blob.
        input >> name >> std::dec >> size;
        // One uint32_t per value: sizeof(uint32_t), not sizeof(pointer).
        // calloc zero-fills, so a truncated file cannot leave garbage behind.
        uint32_t* val = static_cast<uint32_t*>(calloc(size, sizeof(uint32_t)));
        for (uint32_t x = 0; x < size; ++x)
        {
            input >> std::hex >> val[x];
        }
        assert(!input.fail() && "Invalid weight map file.");
        Weights wt{ DataType::kFLOAT, val, size };
        weightMap[name] = wt;
    }
    return weightMap;
}

其余tensorrt初始化部分与tensorrtx其余项目类似不再赘述了。
之前部署yolov4时实现scatterNd算子花了不小的功夫(结果实现后,tensorrt22.01开始支持scatterplugin…)。只要所用算子TensorRT都已支持,用TensorRT部署就很方便。切记根据输入和输出的大小修改分配的缓冲区大小(踩过坑…)。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2022年5月13日 上午10:30
下一篇 2022年5月13日

相关推荐