使用 Pytorch 模型参数进行计算时出现结果类型转换错误
pytorch 252
原文标题 :Result type cast error when doing calculations with Pytorch model parameters
当我运行下面的代码时:
import torchvision
model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
params[var] *= 0.1
报告了一个 RuntimeError:
RuntimeError: result type Float can't be cast to the desired output type Long
但是当我将 params[var] *= 0.1 更改为 params[var] = params[var] * 0.1 时,错误消失了。
为什么会发生这种情况?
我以为 params[var] *= 0.1 和 params[var] = params[var] * 0.1 效果一样。