torchvision.models简介

torchvision.models简介

  • 1 torchvision.models介绍
    • 1.1 torchvision介绍
    • 1.2 torchvision.models
  • 2 导入模型举例
    • 2.1 模型的使用
    • 2.2 模型的修改
    • 2.3 模型的保存和读取

1 torchvision.models介绍

1.1 torchvision介绍

PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms

该篇主要介绍torchvision.models,关于torchvision.datasets 和 torchvision.transforms 可以看以下几篇:
https://blog.csdn.net/Alexa_/article/details/129408512

1.2 torchvision.models

torchvision.models:这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

torchvision.models模块的子模块中包含以下模型结构:

  • AlexNet
  • VGG: VGG-11, VGG-13, VGG-16, VGG-19
  • ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
  • SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
  • DenseNet:Densenet-121,Densenet-169,Densenet-161,Densenet-201

(1),预训练模型可以通过设置pretrained=True来构建

import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)

  • 预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。
  • 图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。

下载的模型可以通过state_dict() 来打印状态参数、缓存的字典,如下所示:

import torchvision.models as models
vgg16 = models.vgg16(pretrained=True)
# 返回包含模块所有状态的字典,包括参数和缓存
pretrained_dict = vgg16.state_dict()

(2),只需要网络结构,不加载参数来初始化,可以将pretrained = False

model = torchvision.models.densenet169(pretrained=False)
# 等价于:
model = torchvision.models.densenet169()

2 导入模型举例

应用VGG16模型,并进行改动,以适应CIDIAR10数据集

  • CIFAR10数据集是 10个类别
  • VGG16输出是1000个类别
  • VGG 加一层输出10个类别

2.1 模型的使用

导入模型,输出查看网络结构:

import torchvision
# 直接调用,实例化模型,pretrained代表是否下载预先训练好的参数
vgg16_false = torchvision.models.vgg16(pretrained = False)
vgg16_ture = torchvision.models.vgg16(pretrained = True)
print(vgg16_ture)

输出结果,可以看到VGG16的结构,可以看出,其最后一行 out_features = 1000.

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
    (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

2.2 模型的修改

  • 修改方式1:如在classifier 新增一层全连接层, 使用 add.module()
vgg16_ture.classifier.add_module("add_linear",torch.nn.Linear(1000,10)) # 在vgg16的classfier里加一层
print(vgg16_ture)

输出之后,在classifier里,可以看到,最后一行:out_features = 10

 (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
    (add_linear): Linear(in_features=1000, out_features=10, bias=True)
  )
  • 修改方式2:直接修改对应层,编码相对应
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10) # 修改对应层,编号相对应
print(vgg16_false)

修改之前与修改之后对比,可以看到classifier中第6行的变化:

#修改前:
(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )

#修改后:
(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )


2.3 模型的保存和读取

模型的保存与存取,两种方式

  • 方式1:直接 保存和加载模型:

  • 方式2:字典方式只保存参数:j加载模型还的调用原来的神经网络

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained = False)

# 保存方法 1 
torch.save(vgg16,"vgg16_method1.pth") # 保存结构模型和参数、保存路径
# 加载模型 1
model = torch.load("vgg16_method1.pth")

# 保存方式 2 -- 以字典方式只保存参数(官方推荐),
torch.save(vgg16.state_dict(),"vgg_method2.pth") 
# 加载方式 2 -- 要恢复网络模型
model = torch.load("vgg_method2.pth")
vgg16 = torchvision.models.vgg16(pretrained = True)
vgg16.load_state_dict(torch.load("vgg_method2.pth"))



参考:

https://blog.csdn.net/shanshangyouzhiyangM/article/details/83143740

https://blog.csdn.net/weixin_44953928/article/details/124216322

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年9月17日
下一篇 2023年9月17日

相关推荐