“NoneType”对象在 Pytorch 模型测试部分不可迭代?
pytorch 263
原文标题 :‘NoneType’ object is not iterable in Pytorch model testing part?
我正在使用 Pytorch 模型研究高光谱图像。当我想测试模型时,我得到了“NoneType objectis not iterable”错误。图像上有一个窗口滑动,我将其添加到测试部分。在 for 循环中,批处理值变为无,我不知道如何修复它。我附上下面的工作代码。谢谢你的帮助。
def test(net, img, hyperparams):
net.eval()
patch_size = hyperparams["patch_size"]
center_pixel = hyperparams["center_pixel"]
batch_size, device = hyperparams["batch_size"], hyperparams["device"]
n_classes = hyperparams["n_classes"]
kwargs = {
"step": hyperparams["test_stride"],
"window_size": (patch_size, patch_size),
}
probs = np.zeros(img.shape[:2] + (n_classes,))
iterations = count_sliding_window(img, **kwargs) // batch_size
for batch in tqdm(
grouper(batch_size, sliding_window(img, **kwargs)),
total=(iterations),
desc="Inference on the image",
):
with torch.no_grad():
if patch_size == 1:
data = [b[0][0, 0] for b in batch]
data = np.copy(data)
data = torch.from_numpy(data)
else:
data = [b[0] for b in batch]
data = np.copy(data)
data = data.transpose(0, 3, 1, 2)
data = torch.from_numpy(data)
data = data.unsqueeze(1)
indices = [b[1:] for b in batch]
data = data.to(device)
output = net(data)
if isinstance(output, tuple):
output = output[0]
output = output.to("cpu")
if patch_size == 1 or center_pixel:
output = output.numpy()
else:
output = np.transpose(output.numpy(), (0, 2, 3, 1))
for (x, y, w, h), out in zip(indices, output):
if center_pixel:
probs[x + w // 2, y + h // 2] += out
else:
probs[x : x + w, y : y + h] += out
return probs
probabilities = test(model, img, hyperparams)
然后我得到了这个错误:
Inference on the image: 0%| | 0/210 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-55-cabf8b6fa06d> in <module>()
----> 1 probabilities = test(model, img, hyperparams)
2 prediction = np.argmax(probabilities, axis=-1)
<ipython-input-54-b248a53c93ce> in test(net, img, hyperparams)
23 with torch.no_grad():
24 if patch_size == 1:
---> 25 data = [b[0][0, 0] for b in batch]
26 data = np.copy(data)
27 data = torch.from_numpy(data)