[PyTorch] 加载部分模型权重

0. 引言

在实际使用中,我们通常希望有一个预训练模型帮助加速训练,如果使用原生的模型,直接使用加载即可。但我们经常会根据不同的任务要求进行backbone的修改,此时直接加载预训练模型就会出错。因此为了解决该问题,下面引入如何加载部分模型的权重(修改的部分不需要加载)。

1. 分类网络最后一层

一般PyTorch官方实现的网络中,训练集使用的ImageNet数据集,所以分类数(1000)与我们的任务分类数是不同的,所以我们不能直接载入最后一层的权重。需对最后的分类层进行处理。

1.1 [key一致] 预训练模型权重的名称和要训练模型权重的名称一致

1.1.1 方法一(PyTorch官方实现)

以ResNet-50为例:

import os
import torch
import torch.nn as nn
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # option1(官方实现)
    net = resnet34()  # 这里没有传入num_classes参数,默认为1000
    """
        因为net的分类数与预训练权重的分类树是一致的,所以接下来可以直接加载预训练权重了
    """
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # change fc layer structure
    in_channel = net.fc.in_features  # 获取最后一层(FC)的输入节点个数
    """
        创建一个新的FC(只是修改了输出通道数),注意,这里相当与是覆盖操作,所以这个FC层的参数不是之前预训练权重的FC
        而是一个新的FC(里面的参数是经过随机初始化后的)
    """
    net.fc = nn.Linear(in_channel, 5)


if __name__ == '__main__':
    main()

1.1.2 方法二(修改预训练权重的有序字典中的key)

import os
import torch
import torch.nn as nn
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # option2
    net = resnet34(num_classes=5)  # 直接传入了自己分类个数(与预训练权重的分类个数不一致)
    """
        因为模型的分类数与预训练的权重分类数不一致,因此不能直接加载预训练权重(会报错)
        所以应该分两步:
            1. 读取预训练权重 -> torch.load()
                读取到权重是一个有序字典。那么为了实现加载该权重不报错,我们可以对这个权重字典进行删减
                首先我们可以通过 net.state_dict()方法来读取模型中所有权重的有序字典(建议使用debug而不是print)
                    我们可以看到模型的有序字典中,倒数第一二的key:
                        'fc.weight'
                        'fc.bias'
            2. 拿到我们想要修改的key,就可以对预训练权重的有序字典进行删减了
    """
    net_weight = net.state_dict()
    pre_weights = torch.load(model_weight_path, map_location=device)
    del_key = []
    for key, _ in pre_weights.items():  # 遍历预训练权重的有序字典
        if "fc" in key:  # 如果key中包含'fc'这个字段
            del_key.append(key)

    for key in del_key:  # 遍历要删除字段的list
        del pre_weights[key]  # 删除预训练权重的key和对应的value

    """
        获得我们想要的预训练权重后,进行权重加载
            这里需要多写一个额外的参数 -> strict=False,即不严格载入每一个key
        
        当strict=False后会返回两个变量:
            ① missing_keys:net网络中的一部分权重并没有在预训练权重中出现(漏掉了一些权重) -> 预训练模型多出来的权重(key)
            ② unexpected_keys:预训练权重的一部分并没有net对应的key(多出来的一些权重)   -> 预训练模型少的权重(key)
    """
    missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    print("[missing_keys]:", *missing_keys, sep="\n")
    print("[unexpected_keys]:", *unexpected_keys, sep="\n")

    """
        Res:
            [missing_keys]:
            fc.weight
            fc.bias
            [unexpected_keys]:
    """


if __name__ == '__main__':
    main()

1.2 [key不一致] 预训练模型权重名称与要训练模型权重的不一致

这时就需要我们自行就预训练权重的名称进行修改(修改有序字典的key值),这必须对模型非常熟悉。

Note: 加载权重其实就是两个有序字典相互匹配的过程,对于每一个字典的value,我们是不用管的(也管不了),我们只需要让二个dict的key对应上就可以

2. 对网络结构进行修改

是否可以加载预训练模型也是需要看我们是怎么修改网络的:

  • 如果我们修改的是高层,那么没有改动的底层的权重是可以加载的(因为底层的权重都是一些比较通用的)
  • 如果我们把底层修改了,那么就不建议加载预训练权重了

对于ResNet-50,如果我们将layer4进行了修改,那么我们需要处理两部分:

  1. layer4的key
  2. fc的key(如果分类数变了就需要修改)

核心代码

del_key = []
for key, _ in pre_weights.items():  # 遍历预训练权重的有序字典
    if "fc" in key or "layer4" in key:  # 多个条件,in的优先级大于or
        del_key.append(key)

完整代码

import os
import torch
import torch.nn as nn
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    net = resnet34(num_classes=5)  # 直接传入了自己分类个数(与预训练权重的分类个数不一致)
    net_weight = net.state_dict()
    pre_weights = torch.load(model_weight_path, map_location=device)
    del_key = []
    for key, _ in pre_weights.items():  # 遍历预训练权重的有序字典
        if "fc" in key or "layer4" in key:  # 多个条件,in的优先级大于or
            del_key.append(key)

    for key in del_key:  # 遍历要删除字段的list
        del pre_weights[key]  # 删除预训练权重的key和对应的value

    missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    print("[missing_keys]:", *missing_keys, sep="\n")
    print("[unexpected_keys]:", *unexpected_keys, sep="\n")

    """
        Res:
            [missing_keys]:
            layer4.0.conv1.weight
            layer4.0.bn1.weight
            layer4.0.bn1.bias
            layer4.0.bn1.running_mean
            layer4.0.bn1.running_var
            layer4.0.conv2.weight
            layer4.0.bn2.weight
            layer4.0.bn2.bias
            layer4.0.bn2.running_mean
            layer4.0.bn2.running_var
            layer4.0.downsample.0.weight
            layer4.0.downsample.1.weight
            layer4.0.downsample.1.bias
            layer4.0.downsample.1.running_mean
            layer4.0.downsample.1.running_var
            layer4.1.conv1.weight
            layer4.1.bn1.weight
            layer4.1.bn1.bias
            layer4.1.bn1.running_mean
            layer4.1.bn1.running_var
            layer4.1.conv2.weight
            layer4.1.bn2.weight
            layer4.1.bn2.bias
            layer4.1.bn2.running_mean
            layer4.1.bn2.running_var
            layer4.2.conv1.weight
            layer4.2.bn1.weight
            layer4.2.bn1.bias
            layer4.2.bn1.running_mean
            layer4.2.bn1.running_var
            layer4.2.conv2.weight
            layer4.2.bn2.weight
            layer4.2.bn2.bias
            layer4.2.bn2.running_mean
            layer4.2.bn2.running_var
            fc.weight
            fc.bias
            [unexpected_keys]:
    """


if __name__ == '__main__':
    main()

参考

  1. https://www.bilibili.com/video/BV1dA411g7Ao?spm_id_from=333.999.0.0

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2022年5月10日
下一篇 2022年5月10日

相关推荐