使用 pytorch 绘制 sin(x) 的导数

社会演员多 pytorch 411

原文标题Plot derivatives of sin(x) using pytorch

我不确定为什么我的代码没有绘制 cos(x) (是的,我知道 pytorch 有 cos(x) 函数)

import math
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random

x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(x)
x.grad == torch.cos(x) # assert x.grad same as cos(x)
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x)

原文链接:https://stackoverflow.com//questions/71666770/plot-derivatives-of-sinx-using-pytorch

回复

我来回复
  • hkchengrex的头像
    hkchengrex 评论

    您需要将上游梯度(在您的情况下等于所有梯度)而不是x作为输入到y.backward()

    因此

    import math
    import torch
    import matplotlib.pyplot as plt
    
    x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
    y = torch.sin(x)
    y.backward(torch.ones_like(x))
    plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
    plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x)
    plt.show()
    
    2年前 0条评论