使用 Pytorch 模型参数进行计算时出现结果类型转换错误

青葱年少 pytorch 215

原文标题Result type cast error when doing calculations with Pytorch model parameters

当我运行下面的代码时:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1

报告了一个 RuntimeError:

RuntimeError: result type Float can't be cast to the desired output type Long

但是当我将 params[var] *= 0.1 更改为 params[var] = params[var] * 0.1 时,错误消失了。

为什么会发生这种情况?

我以为 params[var] *= 0.1 和 params[var] = params[var] * 0.1 效果一样。

原文链接:https://stackoverflow.com//questions/71427796/result-type-cast-error-when-doing-calculations-with-pytorch-model-parameters

回复

我来回复
  • Phoenix的头像
    Phoenix 评论

    首先,让我们知道densenet201中的第一个long型参数,如果模型中有BatchNormalization层,你会发现features.norm0.num_batches_tracked表示训练期间用于计算均值和方差的mini-batch的数量。This parameter is a long-type number and cannot be float type because it behaves like a counter

    其次,在 PyTorch 中,有两种类型的操作:

    • 非就地操作:您将计算后的新输出分配给变量的新副本,例如x = x + 1 或 x = x / 2。赋值前 x 的内存位置不等于赋值后的内存位置,因为您有原始变量的副本。
    • 就地操作:当计算直接应用于变量的原始副本而不在此处进行任何复制时,例如x += 1 或 x /= 2。

    让我们转到您的示例以了解发生了什么:

    1. 非 Inplcae 操作:model = torchvision.models.densenet201(num_classes=10)params = model.state_dict()name = ‘features.norm0.num_batches_tracked’print(id(params[name])) # 140247785908560params[name] = params [name] + 0.1print(id(params[name])) # 140247785908368 print(params[name].type()) # 改为torch.FloatTensor
    2. 就地操作: print(id(params[name])) # 140247785908560params[name] += 1print(id(params[name])) # 140247785908560 print(params[name].type()) # 还是 torch.LongTensorparams[name ] += 0.1 # 你想把原来的拷贝类型改成float,你得到了一个错误

    最后,几点说明:

    • 就地操作可以节省一些内存,但在计算导数时可能会出现问题,因为会立即丢失历史记录。因此,不鼓励使用它们。资源
    • 当您决定使用就地操作时应该谨慎,因为它们会覆盖原始内容。
    • 如果你使用 pandas,这有点类似于 pandas 中的 inplace=True :)。

    这是一个很好的资源,可以阅读更多关于就地操作源的信息,也可以阅读这个讨论源。

    2年前 0条评论