优化器 zero_grad() 和 optimize() – 每个样本分成补丁时如何使用?
原文标题 :Optimizer zero_grad() and optimize() – how to use when each sample split into patches?
我有一个由小型 RGB 图像组成的数据集。然后将每个图像拆分为特定数量的补丁,然后调整每个补丁的大小和模糊(高斯)。我的模型的输入(参见使用 CNN 的热图像增强,用于提高分辨率和处理热图像中的模糊问题的浅层 3 层网络)是调整大小 + 模糊补丁,而预期结果只是调整大小的补丁。
网络相当简单:
class TEN_Network(nn.Module):
def __init__(self) -> None:
super(TEN_Network, self).__init__()
self.model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1, 64, 7, stride=1, padding=3)),
('relu1', nn.ReLU(True)),
('conv2', nn.Conv2d(64, 32, 5, stride=1, padding=2)),
('relu2', nn.ReLU(True)),
('conv3', nn.Conv2d(32, 32, 3, stride=1, padding=1)),
('relu3', nn.ReLU(True)),
('conv4', nn.Conv2d(32, 1, 3, stride=1, padding=1))
]))
def forward(self, x):
x = self.model(x)
return x
我选择这个 CNN 有很多原因,简单就是其中之一。由于我对神经网络和 PyTorch 还很陌生,我认为它会提供一个很好的游乐场。
我的问题是这个 – 假设每个样本都被分成补丁,我应该将每个补丁的梯度(以及分别step()
优化器)归零还是应该计算平均损失(样本中所有补丁的所有损失的总和除以数量补丁)或者我应该在训练我的样本开始和处理补丁之后运行这两个步骤(导致上述每个样本的平均损失)?
目前我有以下(伪代码,优化器是亚当,损失是 MSE):
# ...
for epoch_id in range(0, epochs_total):
# Load dataset with dataloader
dataloader = DataLoader(dataset=custom_dataset, ...)
# Train
for sample_expected, sample_input in next(iter(dataloader)):
loss_sample_avg = 0.0
patches_count = len(sample_expected)
# optimizer.zero_grad() <---- HERE(1)
for patch_expected, patch_input in zip(sample_expected, sample_input):
# optimizer.zero_grad() <---- HERE(2)
# ...
patch_predicted = model(patch_input)
loss = citerion(patch_predicted, patch_input)
loss_sample_avg += loss.item()
loss.backward()
# optimizer.step() <---- HERE(2)?
# optimizer.step() <---- HERE(1)
# Validate
...
这可能是我没有得到一些非常基本的东西。我知道
- 优化器 step() 应该总是跟在 zero_grad() 之后
- 只有在处理完另一个 forward() 传递之后,才可能调用backward()
由于该论文没有提供任何代码(尝试联系作者,但不出所料),我试图弄清楚如何自己实现它。
回复
我来回复-
DerekG 评论
您应该将所有补丁作为一个批次传递。在一个单独的补丁上执行梯度步骤构成纯随机梯度下降,这通常不是优选的,因为它会产生对所需梯度的非常嘈杂的估计。此外,循环遍历图像中的每个补丁在计算上效率低下。
至少,将一张图像中的所有补丁批量处理在一起。不过,理想情况下,您可以将来自多个图像的多个补丁作为一个批次传递。所以
patch_input
和patch_expected
应该是[batch_size x color_channels x patch_height x path_width]
大小。batch_size
本身应该是每个图像的采样补丁数量乘以批次中的图像数量。很可能,除了添加几行来将现有 for 中的输入和目标整理到一个单独的张量中之外,这将需要对您的代码进行很少的修改。我会提供细节,但目前尚不清楚您的目标和输入采用什么形式。
2年前