Pytorch复习笔记–导出Onnx模型为动态输入和静态输入

目录


1–动态输入和静态输入

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2–Pytorch API

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入
input_name = 'input'
output_name = 'output'
torch.onnx.export(model, 
                    input_data, 
                    "Dynamics_InputNet.onnx",
                    opset_version=11,
                    input_names=[input_name],
                    output_names=[output_name],
                    dynamic_axes={
                        input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},
                        output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

3–完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torch
import torch.nn as nn

class Model_Net(nn.Module):
    def __init__(self):
        super(Model_Net, self).__init__()
        self.layer1 = nn.Sequential(

            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        
    def forward(self, data):
        data = self.layer1(data)
        return data

if __name__ == "__main__":

    # 设置输入参数
    Batch_size = 8
    Channel = 3
    Height = 256
    Width = 256
    input_data = torch.rand((Batch_size, Channel, Height, Width))

    # 实例化模型
    model = Model_Net()

    # 导出为静态输入
    input_name = 'input'
    output_name = 'output'
    torch.onnx.export(model, 
                      input_data, 
                      "Static_InputNet.onnx", 
                      verbose=True, 
                      input_names=[input_name], 
                      output_names=[output_name])

    # 导出为动态输入
    torch.onnx.export(model, 
                      input_data, 
                      "Dynamics_InputNet.onnx",
                      opset_version=11,
                      input_names=[input_name],
                      output_names=[output_name],
                      dynamic_axes={
                          input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},
                          output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

4–模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netron

netron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

         动态模型可视化:

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

5–测试动态导出的Onnx模型

import numpy as np
import onnx
import onnxruntime
 
if __name__ == "__main__":
    input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32)
    input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32)
    
    # 导入 Onnx 模型
    Onnx_file = "./Dynamics_InputNet.onnx"
    Model = onnx.load(Onnx_file)
    onnx.checker.check_model(Model) # 验证Onnx模型是否准确
    
    # 使用 onnxruntime 推理
    model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
 
    output1 = model.run([output_name], {input_name:input_data1})
    output2 = model.run([output_name], {input_name:input_data2})
 
    print('output1.shape: ', np.squeeze(np.array(output1), 0).shape)
    print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年3月25日 上午10:13
下一篇 2023年3月25日 上午10:14

相关推荐