PyTorch学习笔记:使用state_dict来保存和加载模型

1. state_dict简介

state_dict是Python的字典对象,可用于保存模型参数、超参数以及优化器(torch.optim)的状态信息。需要注意的是,只有具有可学习参数的层(如卷积层、线性层等)才有state_dict。

下面就拿官方教程中的一个小示例来说明state_dict的使用:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型
model = TheModelClass()

# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

让我们来运行一下以上代码:

从以上代码及运行结果可知,state_dict将模型的每一层映射到一个参数张量。在Python中,可以对state_dict进行保存、加载、更新、修改等操作。

下面我们就来看一下PyTorch如何通过state_dict来保存和加载模型。

2. 保存和加载state_dict

可以通过torch.save()来保存模型的state_dict,即只保存学习到的模型参数,并通过load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为’.pt’或’.pth’。

下面我们就将上个例子中构造的简单模型TheModelClass的参数保存在state_dict,然后通过load_state_dict()来加载模型参数。

......

# 将模型保存到当前路径,名称为test_state_dict.pth
PATH = './test_state_dict.pth'
torch.save(model.state_dict(), PATH)

model = TheModelClass()    # 首先通过代码获取模型结构
model.load_state_dict(torch.load(PATH))   # 然后加载模型的state_dict
model.eval()

注意:load_state_dict()函数只接受字典对象,不可直接传入模型路径,所以需要先使用torch.load()反序列化已保存的state_dict。

另外,在使用模型做推理之前,需要调用model.eval()函数将dropout和batch normalization层设置为评估模式,否则会导致模型推理结果不一致。 

当然,除了保存state_dict,PyTorch还支持保存和加载整个模型。

3. 保存和加载完整模型

保存和加载整个模型的代码如下:

# 保存完整模型
torch.save(model, PATH)

# 加载完整模型
model = torch.load(PATH)
model.eval()

这种方式虽然代码看起来较state_dict方式要简洁,但是灵活性会差一些。因为torch.save()函数使用Python的pickle模块进行序列化,但pickle无法保存模型本身,而是保存包含类的文件路径,该文件会在模型加载时使用。所以当在其他项目对模型进行重构之后,就可能会出现意想不到的错误。

4. 保存和加载checkpoint用于继续训练或推理

除了以上两种保存模型的方式,PyTorch还支持以checkpoint方式保存模型训练的中间结果,以实现模型的继续训练或者推理。这种方式下,保存的内容不仅包含模型的state_dict,还会保存优化器的state_dict,以及其他参数如loss、epoch等。

保存:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载:

model = TheModelClass()
optimizer = TheOptimizerClass()

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()

model.train()

checkpoint在PyTorch中常保存为.tar的文件扩展名。

注:以上checkpoint保存和加载的代码未经本人测试。

5. 迁移学习下的热启动模式

我们在工程中,常常用到迁移学习,利用训练好的模型在新的数据集上进行迁移训练,可达到使用少量数据进行快速训练的目的。

在迁移学习中,我们常常需要对预训练模型进行部分加载的需要,这个时候我们就要用到热启动模式,可通过在load_state_dict()函数中将strict参数设置为False来忽略非匹配键的参数。

# 保存模型state_dict
torch.save(modelA.state_dict(), PATH)

# 热加载模型
modelB = TheModelBClass()
modelB.load_state_dict(torch.load(PATH), strict=False)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2022年5月11日
下一篇 2022年5月11日

相关推荐