PyTorch:通过pth文件查看网络结构(查看输入输出维度)

pth模型保存时是按照“整个模型保存”和“只保存模型参数”会影响模型的加载和访问方式

保存方式为“整个模型”(torch.save(model, PATH)):

import torch
if __name__ == '__main__':
    model_pth = r'D:\${modelPath}\${modelName}.pth'
    net = torch.load(model_pth, map_location=torch.device('cpu'))
    for key, value in net["state_dict"].items():
        print(key,value.size(),sep="  ")

输出(部分截图)为:

保存方式为“只保存模型参数”(torch.save(model.state_dict(), PATH)):

待补充

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2023年8月6日
下一篇 2023年8月6日

相关推荐