利用pytorch 模型载入部分权重

1. 介绍

本文介绍如何在pytorch中载入模型的部分权重, 总结了2个比较常见的问题:

  • 第1个常见的问题: 在分类网络中,当载入的预训练权重的全连接层与我们自己实例化模型的节点个数不一样时,该如何载入?

比如在花卉数据集分类时只有5类,所以最后一层全连接层节点个数为5,但是我们载入的预训练权重是针对ImageNet-1k的权重,它的全连接层节点个数是1000,很明显是不能直接载入预训练模型权重的。

  • 第2个常见的问题: 如果对网络的结构进行了一定的修改,修改之后很明显是不能直接载入预训练权重了。

能不能载入部分权重呢? 当然这要看你对网络是如何修改的,如果你是在网络的高层进行结构的修改的话,那么相对底层的没有被修改过的权重还是可以载入的,因为底层都是比如Backbone都是比较通用的权重,载入之后对我们的训练是很有帮助的。

2. 代码实现说明

以分类网络ResNet为例说明,对应项目中的load_weights.py来介绍对部分权重进行载入。

import os
import torch
import torch.nn as nn
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # option1
    net = resnet34()
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

    # option2
    # net = resnet34(num_classes=5)
    # pre_weights = torch.load(model_weight_path, map_location=device)
    # del_key = []
    # for key, _ in pre_weights.items():
    #     if "fc" in key:
    #         del_key.append(key)
    #
    # for key in del_key:
    #     del pre_weights[key]
    #
    # missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    # print("[missing_keys]:", *missing_keys, sep="\n")
    # print("[unexpected_keys]:", *unexpected_keys, sep="\n")


if __name__ == '__main__':
    main()

下载官方提供的ResNet34预训练模型, 并将它命名为resnet34-pre.pth,接下来介绍官方提供的载入部分权重的方法。

2. 1 pytorch 官方提供方法

  • 首先实例化resnet34模型,注意并没有传入num_classes参数,此时默认的num_classes=1000,此时就可以直接载入官方的预训练权重。因为我们使用的是默认的全连接层个数1000,与预训练权重是一致的。
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

# option1
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
  • 由于我们自己的分类个数是不等于1000的,比如我们这里的分类个数为5,接下来该怎么办呢?首先查看resnet34模型搭建的源码。可以看到全连接层是通过sef.fc=nn.Linear(512*block.expansion,num_class)这条语句实现的。
    在这里插入图片描述
    点开nn.Linear类,可以看到它有这么几个参数self.in_featuresself.out_features,分别表示全连接层的输入和输出的节点个数。对于imagenet-1k,输出节点个数self.out_features对应的就是1000. 因此我们可以通过fc.in_features获得网络的输入节点个数,然后输出节点个数定义为我们自己的分类个数5
net.fc=nn.Linear(in_channel,5)

通过创建新的全连接层来替换原来的全连接层。这样我们就变相的载入了Conv1layer4_x的层结构,替换掉全连接层相当于没有载入全连接层权重,刚好符合我们的要求
在这里插入图片描述

2. 2 另外一种实现方式

net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
    if "fc" in key:
        del_key.append(key)

for key in del_key:
    del pre_weights[key]

missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
  • 首先实例化resnet34,这里需要注意的是我们传入了num_classes参数,也就是最后一个全连接层节点个数一开始就设置为5了。此时就不能像前一种方法一样直接通过net.load_state_dict(torch.load(model_weight_path, map_location=device))载入预训练权重了。因为网络的全连接层节点个数和预训练模型是不一样的,直接载入就会报错。我们应该怎么办呢?
  • 通过(torch.load(model_weight_path, map_location=device),先读取预训练权重保存为一个有序字典Orderedict的形式。每个键值对对应一组参数和权重。
    在这里插入图片描述
  • 由于我们只想保留除全连接层fc之外的预训练权重,我们可以通过遍历pre_weights字典,去删减掉不需要的键值对。通过点击resnet34查看构建的代码,可以看到,其全连接层为self.fc包含了fc字段。除此之外,也可以通过实例化后的模型,调用state_dict()函数,查看模型的所有模型权重的key和value值:
net = resnet34(num_classes=5)
net_weights = net.state_dict()

在这里插入图片描述

  • 可以看到全连接层包含两个权重,分别是fc.weightfc.bias ,此时我们可以遍历pre_weights的每个key值,如果key中包含有fc这个字段我们就可以知道它是属于全连接层的权重,后续把包含fc的权重删除掉,然后我们再去载入剩下的权重。
  • 我们实例化的模型和载入的模型,他们权重的名称(key值)要是一样的才可以载入和方便删减。还有一种情况可能载入模型的key与实例化的模型中的key值不一样。那么这种情况的话就会比较麻烦点。那么就需要将载入模型的key值跟实例化一一对应,将载入模型的key改为实例化模型的key值。这就需要你对网络搭建过程非常清楚,你要知道每个层它所对应的权重是什么,这样的话就可以编辑有序字典中的key来载入你想载入的权重。这个例子我们载入的权重和我们创建的模型它的key值都是一样的,因此相对于刚才说的这种情况,载入会比较简单些。
  • 上面的例子,只要包含了fc字段,我们就将这个key值先存到del_key列表中。通过调试可以发现del_key存的就是fc_weightsfc_bias。紧接着我们再遍历del_key依次将这些key从pre_weights字典中删除。
   pre_weights = torch.load(model_weight_path, map_location=device)
    del_key = []
    for key, _ in pre_weights.items():
        if "fc" in key:
            del_key.append(key)
    
    for key in del_key:
        del pre_weights[key]
  • 这里需要注意,在载入预训练权重的时候,我们多传入了一个参数strict=False, 如果你不传的话,它默认是为True的。如果strict=True它会严格的载入每个key值,因为我们删减掉全连接中的权重,因此就不能将strict设置为True。net.load_state_dict(pre_weights, strict=False)会返回两个 变量,分别是missing_keysunexpected_keys
    • missing_key:表示在我们实例化的模型net中有部分权重并没有在pre_weights预训练权重中出现,就相当于与pre_weights中漏掉了这些权重。
    • unexpected_key:就是说在我们载入的pre_weights中有一部分权重它不在我们的net中,此时就会存在unexpected_keys中。针对我们刚才讲的情况,应该会出现两个missing_key :fc.weights和fc.bias:
      missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
      print("[missing_keys]:", *missing_keys, sep="\n")
      print("[unexpected_keys]:", *unexpected_keys, sep="\n")
    
    执行以后打印的信息:
    >>  [missing_keys]:
    >>   fc.weight
    >>   fc.bias
    >>   [unexpected_keys]:
    
    可以看到missing_key中有fc.weights和fc.bias,在unexpected_keys中是没有任何参数的。也就时除了fc.weights和fc.bias两个全连接参数外,其他参数都载入进来了。

如果有些人,除了fc层外还改动了某些高层的结构如resnet中Conv5_x,我们如何去载入低层没有改动的权重呢?: 此时对于resnet模型就需要载入除了Conv5_xfc层之外的所有权重
在这里插入图片描述
此时我们可以在条件中,判断key是否包含layer4,如果有的话也将它删掉。

 pre_weights = torch.load(model_weight_path, map_location=device)
 del_key = []
 for key, _ in pre_weights.items():
     if "fc" in key or "layer4" in key:
         del_key.append(key)
 
 for key in del_key:
     del pre_weights[key]

在这里插入图片描述
执行之后,我们发现在missing_key列表中除了我们之前两个全连接层权重之外,剩下,剩下的都是layer4所对应的权重,也就是说我们也没有将layer4所对应的权重载入进去。

总结

以上介绍的是2种比较常见的载入部分权重的方法,除了我们讲到的在载入的权重的有序字典筛选之外,我们可以自己新创建一个字典,新创建一个字典之后,可以自己组建key,value然后用上文介绍的方法进行载入就可以了,这样的话会更加的灵活.

  • 在这里感谢B站霹雳吧啦Up主

代码链接:https://pan.baidu.com/s/1j34QBVb9ZKxWX7d1Vm9QrQ?pwd=stxx
提取码:stxx

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2023年3月4日 下午4:23
下一篇 2023年3月4日 下午4:26

相关推荐