怎样c++调用pytorch训练的模型

概括–常用思路:
思路1) pytorch框架模型转libtorch框架模型;
思路2) 将pytorch下.pt模型先转通用的.onnx模型,再使用tensorrt加速工具转.engine模型
(注:不同平台下的加速工具不同,例如Nvidia家tensorRT、Rockchip家RKNN)

一、思路1 pytorch环境模型转libtorch环境模型;
1、模型转换
首先在pytorch环境下,使用torch.jit.trace()torch.jit.scrpit方法,生成libtorch环境需要的.pt模型。
下面以**torch.jit.trace()**为例:

if __name__ == '__main__':
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    cfg.defrost()
    cfg.MODEL.BACKBONE.PRETRAIN = False
    if cfg.MODEL.HEADS.POOL_LAYER == 'FastGlobalAvgPool':
        cfg.MODEL.HEADS.POOL_LAYER = 'GlobalAvgPool'
    model = build_model(cfg)  #!!重要
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)  #!!重要
    if hasattr(model.backbone, 'deploy'):
        model.backbone.deploy(True)
    model.eval()

    inputs = torch.randn(args.batch_size, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(model.device)
    traced_script_module = torch.jit.trace(model, inputs) #!!重要
    traced_script_module.save("script.pt")

2、模型调用
使用相同或更高版本的libtorch,加载上一步骤生成的.pt模型。


    try {
        module_ = torch::jit::load(“xx.pt”);
    }
    catch (const c10::Error& e) {
        std::cerr << "Error loading the model!\n";
        std::exit(EXIT_FAILURE);
    }

二、思路2 pytorch环境模型转libtorch环境模型
将pytorch下.pt模型先转通用的.onnx模型,再使用tensorrt等加速工具转.engine模型**(注:不同平台下的加速工具不同,例如Nvidia家tensorRT、Rockchip家RKNN)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(1)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2022年5月18日
下一篇 2022年5月18日

相关推荐