3.3 线性回归的简洁实现 New Version
下面介绍如何使用pytorch更方便地实现线性回归地训练
3.3.1 生成数据集
%matplotlib inline
# 设置成嵌入显示
import torch
import random
from torch.utils import data
import d2l
#from d2l import torch as d2l
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features,labels = d2l.synthetic_data(true_w,true_b,1000)
与上一节一样,生成数据集
3.3.2 读取数据
读取数据则采用data
包读取
batch_size = 10
data_iter = d2l.load_array((features,labels),batch_size)
next(iter(data_iter))
[tensor([[ 1.7436, -0.5150],
[-0.2934, 0.7571],
[-0.4389, -1.1126],
[ 0.2263, -0.4129],
[ 0.2301, -2.0866],
[ 1.2678, 0.5725],
[-0.4783, 0.7032],
[ 1.4662, 0.0666],
[ 0.0029, 0.4886],
[ 0.1611, 2.0225]]),
tensor([[ 9.4443],
[ 1.0407],
[ 7.0907],
[ 6.0675],
[11.7471],
[ 4.7773],
[ 0.8702],
[ 6.8900],
[ 2.5250],
[-2.3418]])]
这里与上一节一样,因此可以读取数据
3.3.3 定义模型
下面定义模型 torch.nn 中的nn是 neural networks的缩写。nn的核心数据结构是Module,既可以表示神经网络的某个层,也可以表示一个包含很多层的神经网络。
一个实例应该包含一些层和返回输出的前向传播方法。
下面看如何用nn.Module
#import torch.nn as nn
from torch import nn
net = nn.Sequential(nn.Linear(2,1))
3.3.4初始化模型参数
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
tensor([0.])
可以看到两种方法还是有一定的区别。名字上的区别?
3.3.5 定义损失函数
可以看成是一种特殊的层
loss = nn.MSELoss()
3.3.6 定义优化算法
SGD,Adam,RMSProp等,下面创建一个优化算法的实例
trainer = torch.optim.SGD(net.parameters(),lr = 0.03)
也可以为不同的子网络设置不同的学习率,在finetune时经常用到
#optimizer = optim.SGD([
# {'params':net.subnet1.parameters()}, #lr = 0.03
# {'params':net.subnet2.parameters(),'lr': 0.01}
#],lr = 0.03)
也可以调整学习率,或新建学习率
#for param_group in optimizer.param_groups:
# param_group['lr'] *=0.1
3.3.7 训练模型
num_epochs = 3
for epoch in range(num_epochs):
for X,y in data_iter:
l = loss(net(X),y)
trainer.zero_grad()
l.backward()
trainer.step()
#for param_group in optimizer.param_groups:
# param_group['lr'] *=0.1
l = loss(net(features),labels)
print(f'epoch{epoch+1}, loss {l:f}')
epoch1, loss 0.000227
epoch2, loss 0.000097
epoch3, loss 0.000098
下面访问net的层,然后读取权重和偏置,进而与真实值进行比较
w = net[0].weight.data
print('w的误差:',true_w - w.reshape(true_w.shape))
w的误差: tensor([-0.0003, -0.0009])
b = net[0].bias.data
print('b的误差:',true_b - b)
b的误差: tensor([-0.0003])
文章出处登录后可见!
已经登录?立即刷新