CCTrans: Simplifying and Improving Crowd Counting with Transformer解读和代码踩坑记录

CCTrans: Simplifying and Improving Crowd Counting with Transformer

论文地址:https://arxiv.org/pdf/2109.14483v1.pdf

代码:https://github.com/wfs123456/CCTrans

该论文使用的Twins-large模型权重链接已上传百度云:

链接:https://pan.baidu.com/s/1QSTA7P-81_uHD6wqzk93Og 
提取码:8wd4

感谢王福森大佬的复现以及耐心解答。

一、论文解读

整个模型的backbone是Twins,如果不了解可以看我另外一篇文章。Twins: Revisiting the Design of Spatial Attention inVision Transformers解读_算法卷死我算了的博客-CSDN博客

CCTrans: Simplifying and Improving Crowd Counting with Transformer解读和代码踩坑记录首先是一张图片打成patch后放入Twins,王大佬复现的代码里面用的是Twins-SVT-large,这个模型的效果最好。然后会有四个阶段的输出。,取后面三个阶段的输出分别上采样到原图片的1/8。

1.Pyramid Feature Aggregation

        文中所提Pyramid Feature Aggregation就是图中F1,F2,F3.直接用卷积三个阶段的输出,然后将所得concat起来,最后得到的输出放入下一个阶段:Regression Head with Multi-scale Receptive Fields。

2.Regression Head with Multi-scale Receptive Fields:

        文中的意思是Pyramid Feature Aggregation得到的输出复制三份,C1这一路是卷积+空洞率=1的卷积,C2是卷积+空洞率=2的卷积,C3这一路是卷积+空洞率=3的卷积。设置不同的空洞率来解决网格问题(相同的空洞率会丢失一些细节)。

        复现的论文里面只用了空洞卷积,没有在前面加一层卷积。问题应该不大,可能掉点也可能升,有兴趣的可以测试一下。

        C1+C2+C3三路输出concat后与1×1捷径所得加起来。最后一个BN一个ReLU得到最终结果。

然后采用了两种监督方法,第一种全监督,输出density map。第二种弱监督,输出人头数。

 3.LOSS

全监督的loss用的是DM-count(Distribution Matching for Crowd Counting)里面提出的loss。

CCTrans: Simplifying and Improving Crowd Counting with Transformer解读和代码踩坑记录

 L1是计数loss,OT是Optimal Transport loss,TV是Total Variation loss。代码里面用的是这几个loss,文章中改进了一下把TV loss换成了L2 loss(均方误差)。

 CCTrans: Simplifying and Improving Crowd Counting with Transformer解读和代码踩坑记录

 弱监督的loss就简单的一个Smooth L1 loss

CCTrans: Simplifying and Improving Crowd Counting with Transformer解读和代码踩坑记录

4.结果

CCTrans: Simplifying and Improving Crowd Counting with Transformer解读和代码踩坑记录

 二、代码踩坑记录

1.timm==0.4.12

个人感觉timm库换个版本经常有各种问题,一开始我的服务器上timm是0.3几的,然后代码就会报错,所以严格使用0.4.12的timm库

2.数据集问题

我用的是上海A,网上下载的ground_truth文件夹是下横杠,而代码里是直接写死的ground-truth,直接运行会报找不到mat文件的错误。在无数次检查路径甚至实在找不到后问了作者也没找到。最后突然看见了这个横杠的区别。

修改方法1:把train和test里面的ground_truth改为ground-truth

CCTrans: Simplifying and Improving Crowd Counting with Transformer解读和代码踩坑记录

 修改方法2:代码datasets文件夹中crowd.py184行代码修改成

  def __getitem__(self, item):
        img_path = self.im_list[item]
        name = os.path.basename(img_path).split('.')[0]
        gd_path = os.path.join(self.root_path, 'ground_truth', 'GT_{}.mat'.format(name)) # 把ground-truth改为ground_truth

3.实验设备问题

我的服务器是3070+linux,batch-size改成8,max-epoch不用4000,差不多1500和代码作者一样就可以取得好的效果。据我导师说batch-size对精度还是有一定影响的。所以我复现下来batch_size=8 单卡3070 测试mae55.6,mse97.16

最新结果:batch_size=24 单卡,测试mae54.09,mse92.87  epoch2000

epoch1500的mae54.09,mse99.01 。

4.test找不到model_path问题

测试的时候,需要给定model_path和data_path的路径,作者没有写默认值,自己加就行。还有就是把parser.add_argument('–model-path', type=str, required=True, 这里的required=True去掉,不然一直会显示找不到model-path的路径。

源代码有个很严重的问题,经常使用-,所以在argument的代码里面我把所有的-都改成了_。test修改过后的代码如图

parser = argparse.ArgumentParser(description='Test ')
parser.add_argument('--device', default='0', help='assign device')
parser.add_argument('--batch_size', type=int, default=8,
                        help='train batch size')
parser.add_argument('--crop_size', type=int, default=256,
                    help='the crop size of the train image')
parser.add_argument('--model_path', default='C:/Users/Desktop/CCTrans-main/best.pth', type=str,
                    help='saved model path')
parser.add_argument('--data_path', type=str, default='C:/Users/Desktop/CCTrans-main/shanghaitech/part_A_final',
                    help='dataset path')
parser.add_argument('--dataset', type=str, default='sha',
                    help='dataset name: qnrf, nwpu, sha, shb, custom')
parser.add_argument('--pred-density-map-path', type=str, default='inference_results',
                    help='save predicted density maps when pred-density-map-path is not empty.')

 5.test报错:fixture ‘xxx’ not found

这个bug挺奇怪的,一开始我运行test.py都没有报过这个错误。昨天报了这个错误。查了一下,使用了以下几个方法。

1.由于pycharm中以pytest运行,它会默认把test、test_开头的.py文件当做单元测试,所以需要修改文件名。将test_image_patch.py.py 修改为image_patch.py ,再次运行,我没有成功。

2.因为程序是在是在测试环境下运行的, 只要将环境改成 Unittests 。File–> Settings –> Tools –> Python Integrated Tools –> Default test runner , 将pytest改为Unittests即可。

方法二成功了。反正很奇怪这个bug,只有我最近用的那台服务器报这个错。其他服务器正常。仅记录作为参考。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年3月4日 下午4:17
下一篇 2023年3月4日 下午4:18

相关推荐