0. 引言
在实际使用中,我们通常希望有一个预训练模型帮助加速训练,如果使用原生的模型,直接使用加载即可。但我们经常会根据不同的任务要求进行backbone的修改,此时直接加载预训练模型就会出错。因此为了解决该问题,下面引入如何加载部分模型的权重(修改的部分不需要加载)。
1. 分类网络最后一层
一般PyTorch官方实现的网络中,训练集使用的ImageNet数据集,所以分类数(1000)与我们的任务分类数是不同的,所以我们不能直接载入最后一层的权重。需对最后的分类层进行处理。
1.1 [key一致] 预训练模型权重的名称和要训练模型权重的名称一致
1.1.1 方法一(PyTorch官方实现)
以ResNet-50为例:
import os
import torch
import torch.nn as nn
from model import resnet34
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# option1(官方实现)
net = resnet34() # 这里没有传入num_classes参数,默认为1000
"""
因为net的分类数与预训练权重的分类树是一致的,所以接下来可以直接加载预训练权重了
"""
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# change fc layer structure
in_channel = net.fc.in_features # 获取最后一层(FC)的输入节点个数
"""
创建一个新的FC(只是修改了输出通道数),注意,这里相当与是覆盖操作,所以这个FC层的参数不是之前预训练权重的FC
而是一个新的FC(里面的参数是经过随机初始化后的)
"""
net.fc = nn.Linear(in_channel, 5)
if __name__ == '__main__':
main()
1.1.2 方法二(修改预训练权重的有序字典中的key)
import os
import torch
import torch.nn as nn
from model import resnet34
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# option2
net = resnet34(num_classes=5) # 直接传入了自己分类个数(与预训练权重的分类个数不一致)
"""
因为模型的分类数与预训练的权重分类数不一致,因此不能直接加载预训练权重(会报错)
所以应该分两步:
1. 读取预训练权重 -> torch.load()
读取到权重是一个有序字典。那么为了实现加载该权重不报错,我们可以对这个权重字典进行删减
首先我们可以通过 net.state_dict()方法来读取模型中所有权重的有序字典(建议使用debug而不是print)
我们可以看到模型的有序字典中,倒数第一二的key:
'fc.weight'
'fc.bias'
2. 拿到我们想要修改的key,就可以对预训练权重的有序字典进行删减了
"""
net_weight = net.state_dict()
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items(): # 遍历预训练权重的有序字典
if "fc" in key: # 如果key中包含'fc'这个字段
del_key.append(key)
for key in del_key: # 遍历要删除字段的list
del pre_weights[key] # 删除预训练权重的key和对应的value
"""
获得我们想要的预训练权重后,进行权重加载
这里需要多写一个额外的参数 -> strict=False,即不严格载入每一个key
当strict=False后会返回两个变量:
① missing_keys:net网络中的一部分权重并没有在预训练权重中出现(漏掉了一些权重) -> 预训练模型多出来的权重(key)
② unexpected_keys:预训练权重的一部分并没有net对应的key(多出来的一些权重) -> 预训练模型少的权重(key)
"""
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
"""
Res:
[missing_keys]:
fc.weight
fc.bias
[unexpected_keys]:
"""
if __name__ == '__main__':
main()
1.2 [key不一致] 预训练模型权重名称与要训练模型权重的不一致
这时就需要我们自行就预训练权重的名称进行修改(修改有序字典的key值),这必须对模型非常熟悉。
Note: 加载权重其实就是两个有序字典相互匹配的过程,对于每一个字典的value,我们是不用管的(也管不了),我们只需要让二个dict的key对应上就可以
2. 对网络结构进行修改
是否可以加载预训练模型也是需要看我们是怎么修改网络的:
- 如果我们修改的是高层,那么没有改动的底层的权重是可以加载的(因为底层的权重都是一些比较通用的)
- 如果我们把底层修改了,那么就不建议加载预训练权重了
对于ResNet-50,如果我们将layer4
进行了修改,那么我们需要处理两部分:
layer4
的keyfc
的key(如果分类数变了就需要修改)
核心代码:
del_key = []
for key, _ in pre_weights.items(): # 遍历预训练权重的有序字典
if "fc" in key or "layer4" in key: # 多个条件,in的优先级大于or
del_key.append(key)
完整代码:
import os
import torch
import torch.nn as nn
from model import resnet34
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
net = resnet34(num_classes=5) # 直接传入了自己分类个数(与预训练权重的分类个数不一致)
net_weight = net.state_dict()
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items(): # 遍历预训练权重的有序字典
if "fc" in key or "layer4" in key: # 多个条件,in的优先级大于or
del_key.append(key)
for key in del_key: # 遍历要删除字段的list
del pre_weights[key] # 删除预训练权重的key和对应的value
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
"""
Res:
[missing_keys]:
layer4.0.conv1.weight
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.bn1.running_mean
layer4.0.bn1.running_var
layer4.0.conv2.weight
layer4.0.bn2.weight
layer4.0.bn2.bias
layer4.0.bn2.running_mean
layer4.0.bn2.running_var
layer4.0.downsample.0.weight
layer4.0.downsample.1.weight
layer4.0.downsample.1.bias
layer4.0.downsample.1.running_mean
layer4.0.downsample.1.running_var
layer4.1.conv1.weight
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.bn1.running_mean
layer4.1.bn1.running_var
layer4.1.conv2.weight
layer4.1.bn2.weight
layer4.1.bn2.bias
layer4.1.bn2.running_mean
layer4.1.bn2.running_var
layer4.2.conv1.weight
layer4.2.bn1.weight
layer4.2.bn1.bias
layer4.2.bn1.running_mean
layer4.2.bn1.running_var
layer4.2.conv2.weight
layer4.2.bn2.weight
layer4.2.bn2.bias
layer4.2.bn2.running_mean
layer4.2.bn2.running_var
fc.weight
fc.bias
[unexpected_keys]:
"""
if __name__ == '__main__':
main()
参考
- https://www.bilibili.com/video/BV1dA411g7Ao?spm_id_from=333.999.0.0
文章出处登录后可见!
已经登录?立即刷新