一、检测问题解决
今天在用detect.py以及自己训练的模型检测图片时,遇到了一些问题,在这总结一下子
1、修改代码
看下面这段代码,是yolo.py中的Model类中的forward函数,训练的时候augment=False
,而检测时我们得把它改为true,augment=True
def forward(self, x, augment=True, profile=False): # 训练:augment=False
if augment: # TTA(Test Time Augmentation) 只在测试和检测的时候做的数据增强
return self.forward_augment(x) # augmented inference, None
else: # 如果是训练的过程则执行else
return self.forward_once(x, profile) # single-scale inference, train
2、还有一个特别值得注意的问题是,当你使用你自己的训练模型检测图片时,你的网络模型得用的是和训练时一模一样的,不能修改,不然会报错,因为在检测时,也会执行yolo.py
和detect.py
文件
例如,我训练模型时用的concat
模块是这样的:(common.py文件中)
这是我把原来的concat模块给改了
# biFPN 1
class Concat(nn.Module):
def __init__(self, c1):
super(Concat, self).__init__()
# self.relu = nn.ReLU()
self.w = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
self.epsilon = 0.0001
self.swish = Swish()
def forward(self, x):
weight = self.w / (torch.sum(self.w, dim=0) + self.epsilon)
# Connections for P6_0 and P7_0 to P6_1 respectively
x = self.swish(weight[0] * x[0] + weight[1] * x[1])
return x
然而我检测时又用了原来的concat
模块:
# 原
class Concat(nn.Module):
# Concatenate a list of tensors along dimension
def __init__(self, dimension=1):
super(Concat, self).__init__()
self.d = dimension
def forward(self, x):
return torch.cat(x, self.d)
这样的话当然是不行的哈
二、检测疑惑解答
今天,在检测过程中发现出现了网络层数和参数的打印信息,我很疑惑为什么会有这种信息,检测不就是直接拿训练好的模型去检测吗,为什么会有网络信息??
原来训练好的.pt
文件只是一组参数,用它检测图片时还得把它加入到网络模型中,用网络模型去处理检测图片,这也是为什么第一个问题中检测模型要和训练模型完全一致的原因。就好比你拿狗训练的参数要套在猫身上肯定是对不上的咯!
文章出处登录后可见!
已经登录?立即刷新