参考博客:
ModuleNotFoundError: No module named ‘models‘解决torch.load问题【天坑】
保存与加载
使用 torch.save(model, “my_model.pth”) 命令可以保存整个模型。
这个保存/加载过程使用最直观的语法,涉及的代码最少。
以这种方式保存模型将使用Python的 pickle 模块保存整个model。
但是,在进行torch.load(“my_model.pth”)时,加载目录与保存目录要相同,这里的目录不是"my_model.pth",而是项目中定义模型所涉及的目录。举个例子:
load-test工程中有model_1文件夹,model_1文件夹中有yolo.py模块。
yolo.py:
import torch
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__() # 第一句话,调用父类的构造函数
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu1 = torch.nn.ReLU()
self.max_pooling1 = torch.nn.MaxPool2d(2, 1)
self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
self.relu2 = torch.nn.ReLU()
self.max_pooling2 = torch.nn.MaxPool2d(2, 1)
self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
self.dense2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.max_pooling1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.max_pooling2(x)
x = self.dense1(x)
x = self.dense2(x)
return x
在train.py中导入模型并保存整个模型:
from model_1.yolo import MyNet
import torch
net = MyNet()
torch.save(net, './weights/net.pt') # 保存
将net.pt保存在weights文件夹中。
然后在train.py中加载模型:
import torch
model = torch.load('./weights/net.pt')
是成功的(也不需要 from model_1.yolo import MyNet )。
如果将model_1文件夹改为model_2文件夹,则出错。
出错的原因是导入模型时的目录与保存时的目录不一致。
解决方案
有时候需要将训练好的模型导入另一个工程中,但是该工程的文件夹与原文件夹相同且不方便更改,这时可以采取: 先加载模型,然后保存模型参数,再加载模型参数 的方法来正确加载模型。
先保持model_1文件夹不变。
train.py:
import torch
model = torch.load('./weights/net.pt')
torch.save(model.state_dict(), './weights/net_state_dict.pt') # 保存
将model_1文件夹改为model_2文件夹。
train.py:
from model_2.yolo import MyNet
import torch
model = MyNet()
state_dict = torch.load('./weights/net_state_dict.pt')
model.load_state_dict(state_dict)
加载成功。
文章出处登录后可见!
已经登录?立即刷新