在pytorch中调用tensor.backward()时出错可能是由grad_fn=引起的
pytorch 707
原文标题 :An error when calling tensor.backward() in pytorch may caused by grad_fn=
我在 pytorch backward() 函数中遇到问题,这是我的代码
c = torch.empty(4)
a = torch.tensor(2.,requires_grad=True)
b = torch.tensor(3.,requires_grad=True)
c[0] = a*2
#c[0]:tensor(4., grad_fn=<SelectBackward>)
#c:tensor([4.0000e+00, 3.1720e+00, 1.0469e-38, 9.2755e-39], grad_fn=<CopySlices>)
c[0].backward()
这样就可以了,我可以得到正确的答案 a.grad==tensor(2.),但是如果我在上面的代码之后排除以下代码:
c[1] = b*2
c[1].backward()
这将导致以下错误:
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
但是如果我在第一部分写c[0].backward(retain_grad=True)
,它不会导致错误。
有人可以告诉我在倒退期间释放了哪些中间结果。 CopySlices 有问题吗?
太感谢了!
回复
我来回复-
noob 评论
为了减少内存使用,在
.backward()
调用期间,所有中间结果在不再需要时被删除。因此,如果您再次尝试调用.backward()
,则中间结果不存在并且无法执行反向传递(并且您会收到您看到的错误)。您可以调用.backward(retain_graph=True)
进行不会删除中间结果的反向传递,并且这样您就可以再次呼叫.backward()
。除了最后一次调用backward之外的所有调用都应该有retain_graph=True
选项。c[0] = a*2 #c[0]:tensor(4., grad_fn=<SelectBackward>) #c:tensor([4.0000e+00, 3.1720e+00, 1.0469e-38, 9.2755e-39], grad_fn=<CopySlices>) c[0].backward(retain_graph=True)
c[1] = b*2 c[1].backward(retain_graph=True) ```
2年前