使用 Pytorch 模型参数进行计算时出现结果类型转换错误

青葱年少 pytorch 252

原文标题Result type cast error when doing calculations with Pytorch model parameters

当我运行下面的代码时:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1

报告了一个 RuntimeError:

RuntimeError: result type Float can't be cast to the desired output type Long

但是当我将 params[var] *= 0.1 更改为 params[var] = params[var] * 0.1 时,错误消失了。

为什么会发生这种情况?

我以为 params[var] *= 0.1 和 params[var] = params[var] * 0.1 效果一样。

原文链接:https://stackoverflow.com//questions/71427796/result-type-cast-error-when-doing-calculations-with-pytorch-model-parameters

回复

我来回复
  • 暂无回复内容