【作业向】
根据给定的猫狗分类数据集,对比 单层CNN模型、从头训练CNN模型(mobileNet)、微调预训练CNN模型(mobileNet)的差异。生成的模型的正向传播图(相关方法见我)。
使用PyTorch实现。
本文代码(数据集在同目录下):我的Github
文章目录
- 关于数据集
- 建立Dataset对象
- 模型1:单层卷积+单层池化+全连接
- 定义训练和评估函数
- 模型2:从头训练(MobileNet)
- 模型3:预训练模型+微调(MobileNet)
- 保存模型
- 前向传播可视化
- 测试集评估模型效果
关于数据集
数据集结构很简单,训练集和测试集分两个目录,分别对应cat和dog两个文件夹。
只需要使用torchvision.datasets下的ImageFolder方法即可。
建立Dataset对象
包括建立Dataset对象、训练集和验证集的DataLoader对象。
模型1:单层卷积+单层池化+全连接
定义训练和评估函数
# 定义train函数,使用GPU训练并评价模型
import time
# 测试集上评估准确率
def evaluate_accuracy(data_iter, net, device=None):
"""评估模型预测正确率"""
if device is None and isinstance(net, torch.nn.Module):
# 如果没指定device就用net的device
device = list(net.parameters())[0].device
acc_sum, n = 0.0, 0
with torch.no_grad():
for X, y in data_iter:
if isinstance(net, torch.nn.Module):
net.eval() # 将模型net调成 评估模式,这会关闭dropout
# 累加这一个batch数据中判断正确的个数
acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
net.train() # 将模型net调回 训练模式
else: # 针对自定义的模型(几乎用不到)
if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
# 将 is_training 设置成False
acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
else:
acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
n += y.shape[0]
return acc_sum / n
def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
net = net.to(device)
print('training on ', device)
batch_count = 0
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
for X, y in train_iter:
X = X.to(device)
y = y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
optimizer.zero_grad()
l.backward()
optimizer.step()
train_l_sum += l.cpu().item()
train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
n += y.shape[0]
batch_count += 1
test_acc = evaluate_accuracy(test_iter, net)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
% (epoch+1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
模型2:从头训练(MobileNet)
模型3:预训练模型+微调(MobileNet)
保存模型
前向传播可视化
(绘图相关方法参考)
测试集评估模型效果
文章出处登录后可见!
已经登录?立即刷新