使用 pytorch 绘制 sin(x) 的导数
pytorch 462
原文标题 :Plot derivatives of sin(x) using pytorch
我不确定为什么我的代码没有绘制 cos(x) (是的,我知道 pytorch 有 cos(x) 函数)
import math
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(x)
x.grad == torch.cos(x) # assert x.grad same as cos(x)
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x)
回复
我来回复-
hkchengrex 评论
该回答已被采纳!
您需要将上游梯度(在您的情况下等于所有梯度)而不是
x
作为输入到y.backward()
。因此
import math import torch import matplotlib.pyplot as plt x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True) y = torch.sin(x) y.backward(torch.ones_like(x)) plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)') plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x) plt.show()
2年前