在pytorch中调用tensor.backward()时出错可能是由grad_fn=引起的

xiaoxingxing pytorch 707

原文标题An error when calling tensor.backward() in pytorch may caused by grad_fn=

我在 pytorch backward() 函数中遇到问题,这是我的代码

c = torch.empty(4)

a = torch.tensor(2.,requires_grad=True)

b = torch.tensor(3.,requires_grad=True)

c[0] = a*2

#c[0]:tensor(4., grad_fn=<SelectBackward>)

#c:tensor([4.0000e+00, 3.1720e+00, 1.0469e-38, 9.2755e-39], grad_fn=<CopySlices>)

c[0].backward()

这样就可以了,我可以得到正确的答案 a.grad==tensor(2.),但是如果我在上面的代码之后排除以下代码:

c[1] = b*2

c[1].backward()

这将导致以下错误:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

但是如果我在第一部分写c[0].backward(retain_grad=True),它不会导致错误。

有人可以告诉我在倒退期间释放了哪些中间结果。 CopySlices 有问题吗?

太感谢了!

原文链接:https://stackoverflow.com//questions/71408271/an-error-when-calling-tensor-backward-in-pytorch-may-caused-by-grad-fn-copysl

回复

我来回复
  • noob的头像
    noob 评论

    为了减少内存使用,在.backward()调用期间,所有中间结果在不再需要时被删除。因此,如果您再次尝试调用.backward(),则中间结果不存在并且无法执行反向传递(并且您会收到您看到的错误)。您可以调用.backward(retain_graph=True)进行不会删除中间结果的反向传递,并且这样您就可以再次呼叫.backward()。除了最后一次调用backward之外的所有调用都应该有retain_graph=True选项。

    c[0] = a*2
    
    #c[0]:tensor(4., grad_fn=<SelectBackward>)
    
    #c:tensor([4.0000e+00, 3.1720e+00, 1.0469e-38, 9.2755e-39], grad_fn=<CopySlices>)
    
    c[0].backward(retain_graph=True)
    
    
    c[1] = b*2
    
    c[1].backward(retain_graph=True)
    ```
    
    2年前 0条评论