CCTrans: Simplifying and Improving Crowd Counting with Transformer
论文地址:https://arxiv.org/pdf/2109.14483v1.pdf
代码:https://github.com/wfs123456/CCTrans
该论文使用的Twins-large模型权重链接已上传百度云:
链接:https://pan.baidu.com/s/1QSTA7P-81_uHD6wqzk93Og
提取码:8wd4
感谢王福森大佬的复现以及耐心解答。
一、论文解读
整个模型的backbone是Twins,如果不了解可以看我另外一篇文章。Twins: Revisiting the Design of Spatial Attention inVision Transformers解读_算法卷死我算了的博客-CSDN博客
首先是一张图片打成patch后放入Twins,王大佬复现的代码里面用的是Twins-SVT-large,这个模型的效果最好。然后会有四个阶段的输出。,取后面三个阶段的输出分别上采样到原图片的1/8。
1.Pyramid Feature Aggregation
文中所提Pyramid Feature Aggregation就是图中F1,F2,F3.直接用卷积三个阶段的输出,然后将所得concat起来,最后得到的输出放入下一个阶段:Regression Head with Multi-scale Receptive Fields。
2.Regression Head with Multi-scale Receptive Fields:
文中的意思是Pyramid Feature Aggregation得到的输出复制三份,C1这一路是卷积+空洞率=1的卷积,C2是卷积+空洞率=2的卷积,C3这一路是卷积+空洞率=3的卷积。设置不同的空洞率来解决网格问题(相同的空洞率会丢失一些细节)。
复现的论文里面只用了空洞卷积,没有在前面加一层卷积。问题应该不大,可能掉点也可能升,有兴趣的可以测试一下。
C1+C2+C3三路输出concat后与1×1捷径所得加起来。最后一个BN一个ReLU得到最终结果。
然后采用了两种监督方法,第一种全监督,输出density map。第二种弱监督,输出人头数。
3.LOSS
全监督的loss用的是DM-count(Distribution Matching for Crowd Counting)里面提出的loss。
L1是计数loss,OT是Optimal Transport loss,TV是Total Variation loss。代码里面用的是这几个loss,文章中改进了一下把TV loss换成了L2 loss(均方误差)。
弱监督的loss就简单的一个Smooth L1 loss
4.结果
二、代码踩坑记录
1.timm==0.4.12
个人感觉timm库换个版本经常有各种问题,一开始我的服务器上timm是0.3几的,然后代码就会报错,所以严格使用0.4.12的timm库
2.数据集问题
我用的是上海A,网上下载的ground_truth文件夹是下横杠,而代码里是直接写死的ground-truth,直接运行会报找不到mat文件的错误。在无数次检查路径甚至实在找不到后问了作者也没找到。最后突然看见了这个横杠的区别。
修改方法1:把train和test里面的ground_truth改为ground-truth
修改方法2:代码datasets文件夹中crowd.py184行代码修改成
def __getitem__(self, item):
img_path = self.im_list[item]
name = os.path.basename(img_path).split('.')[0]
gd_path = os.path.join(self.root_path, 'ground_truth', 'GT_{}.mat'.format(name)) # 把ground-truth改为ground_truth
3.实验设备问题
我的服务器是3070+linux,batch-size改成8,max-epoch不用4000,差不多1500和代码作者一样就可以取得好的效果。据我导师说batch-size对精度还是有一定影响的。所以我复现下来batch_size=8 单卡3070 测试mae55.6,mse97.16
最新结果:batch_size=24 单卡,测试mae54.09,mse92.87 epoch2000
epoch1500的mae54.09,mse99.01 。
4.test找不到model_path问题
测试的时候,需要给定model_path和data_path的路径,作者没有写默认值,自己加就行。还有就是把parser.add_argument('–model-path', type=str, required=True, 这里的required=True去掉,不然一直会显示找不到model-path的路径。
源代码有个很严重的问题,经常使用-,所以在argument的代码里面我把所有的-都改成了_。test修改过后的代码如图
parser = argparse.ArgumentParser(description='Test ')
parser.add_argument('--device', default='0', help='assign device')
parser.add_argument('--batch_size', type=int, default=8,
help='train batch size')
parser.add_argument('--crop_size', type=int, default=256,
help='the crop size of the train image')
parser.add_argument('--model_path', default='C:/Users/Desktop/CCTrans-main/best.pth', type=str,
help='saved model path')
parser.add_argument('--data_path', type=str, default='C:/Users/Desktop/CCTrans-main/shanghaitech/part_A_final',
help='dataset path')
parser.add_argument('--dataset', type=str, default='sha',
help='dataset name: qnrf, nwpu, sha, shb, custom')
parser.add_argument('--pred-density-map-path', type=str, default='inference_results',
help='save predicted density maps when pred-density-map-path is not empty.')
5.test报错:fixture ‘xxx’ not found
这个bug挺奇怪的,一开始我运行test.py都没有报过这个错误。昨天报了这个错误。查了一下,使用了以下几个方法。
1.由于pycharm中以pytest运行,它会默认把test、test_开头的.py文件当做单元测试,所以需要修改文件名。将test_image_patch.py.py 修改为image_patch.py ,再次运行,我没有成功。
2.因为程序是在是在测试环境下运行的, 只要将环境改成 Unittests 。File–> Settings –> Tools –> Python Integrated Tools –> Default test runner , 将pytest改为Unittests即可。
方法二成功了。反正很奇怪这个bug,只有我最近用的那台服务器报这个错。其他服务器正常。仅记录作为参考。
文章出处登录后可见!