【CV学习笔记】之tensorrt篇之Vision transformer

1、摘要

本次学习内容主要学习了vision transformer的网络结构,并在cpp和python中实现了后处理代码(其实没啥后处理的,取最大值即可),同时加强了对transformer原理的理解,主要是为了学习detr等模型做铺垫。

原vit学习地址:Vision Transformer详解

个人学习代码仓库地址: https://github.com/Rex-LK/tensorrt_learning

2、vit简介

Transformer 原本是在nlp领域提出的模型,后来发现在图像处理领域也有很好的效果,并且transformer日益成为图像处理中的主干网络,甚至在不久的将来,有替换掉cnn的可能,于是,学习transformer是十分必要的。

transformer的主要结构都是围绕了qkv三个向量来运作的,其中

Q为query,可认为是查询向量,当前向量与其他向量计算注意力

K为key,可认为是被查向量,其他向量与当前向量计算注意力

V为注意力计算的结果

其主要过程可利用一个公式表示

【CV学习笔记】之tensorrt篇之Vision transformer

其中【CV学习笔记】之tensorrt篇之Vision transformer为向量长度,用于减小不同输入向量之间的差距。

vit是基于transformer的基础上,将图像均分为多个patch,然后再将每个patch作为一个token输入到模型中,计算每个patch之间的注意力,引用原论文中的一张图片。

【CV学习笔记】之tensorrt篇之Vision transformer

上图很好的反应了vit的结构,其中每个patch的位置信息是独一无二的,如果不加上Position Embedding的话,打乱patch的顺序是不改变注意力的结果的,这显然不适用于图像处理,transformer中使用了固定的位置编码,而在vit中使用了可以训练的位置编码。

在经过一些列的Transformer Encode块之后堆叠之后,输入维度为[197,768],输出维度也为[197,768],然后采用MLP直接进行分类,下图是原文作者绘制的网络结构图,简直是太清晰了,不愧是B站导师,由此,vit模型的网络结构图已经学习的差不多了,更加详细的代码和讲解可以参考讲解视频,现在就是要将这个模型在tensorrt中跑起来。

【CV学习笔记】之tensorrt篇之Vision transformer

3、tensorrt加速

3.1、pytorch2onnx

原模型可以在https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth进行下载,也可以到百度网盘上进行下载,下载完成后,进入到demo/vit/文件夹中,然后原项目中的数据集进行训练,可以得到自己训练的分类模型,然后导出onnx并使用onnx_simplify

python export_onnx.py
python onnx_simplify.py

利用netron查看生成的vit.onnx模型,这里主要查看最下面output是否有问题。导出onnx后,可以利用predict.py和infer-onnxruntime.py进行推理来验证导出onnx的正确性

在torch下模型的输出为

tensor([0.8261, 0.0730, 0.0128, 0.0589, 0.0292])

onnxruntime下的输出为

[0.8260881  0.0730136  0.01283037 0.05891288 0.02915508]

3.2、python -tensorrt加速

由于分类模型没有后处理,直接获得模型输出中的最大值即可。

3.3、cpp-tensorrt加速

其构造engine的模型与其他模型一致,后处理只需要一行即可搞定

int predict_label = std::max_element(prob, prob + num_classes) - prob;

4、总结

本次学习笔记学习了transformer的基本结构,并引申到了图像领域,对vision transformer的基本网络结构有了大致的了解,本次学习内容是为了为之后的detr的学习做铺垫。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
上一篇 2022年6月13日 下午12:27
下一篇 2022年6月13日 下午12:30

相关推荐

本站注重文章个人版权,不会主动收集付费或者带有商业版权的文章,如果出现侵权情况只可能是作者后期更改了版权声明,如果出现这种情况请主动联系我们,我们看到会在第一时间删除!本站专注于人工智能高质量优质文章收集,方便各位学者快速找到学习资源,本站收集的文章都会附上文章出处,如果不愿意分享到本平台,我们会第一时间删除!