温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

pytorch加载模型遇到的问题怎么解决

发布时间:2022-03-18 16:59:47 来源:亿速云 阅读:487 作者:iii 栏目:大数据

这篇文章主要讲解了“pytorch加载模型遇到的问题怎么解决”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“pytorch加载模型遇到的问题怎么解决”吧!

1. 查看网络参数

pretrained_dict1 = torch.load(model_path2, map_location='cpu')['state_dict']#预训练文件后缀是.tarpretrained_dict2 = torch.load(model_path3)#预训练文件后缀是.pth#1.查看预训练网络参数for key ,value in pretrained_dict1.items():#pretrained_dict1,pretrained_dict2就是上面的东西count+=1print(key)print(count)#2.查看model的网络参数for key ,value in model.state_dict.items():print(key,value)

2. 加载模型遇到的两大问题

1. 模型的键不匹配

以下两代码,解决了键不匹配问题,一个是删除键的某一部分,一是添加键的某一部分

例:
下面的错误是因为模型的model.state_dict().items()的键是conv1.weight,预训练的键是module.conv1.weight,导致不匹配。所以下面的代码是让module. 去掉
pytorch加载模型遇到的问题怎么解决

1.删除键的头部
pretrained_dict = {
   
   
   k.replace('module.', ''): v for k, v in pretrained_dict2.items()}

当然有时候自己model的键需要改进,如下

2.补齐键的头部
checkpoint={
   
   
   'module.'+k:v for k,v in pretrained_dict.items()}

2. 预训练模型和自己的model长度不一样

# 删除pretrained_dict.items()中model所没有的东西model_dict = model.state_dict()pretrained_dict = {
   
   
   k: v for k, v in pretrained_dict.items() if k in model_dict}  # 只保留预训练模型中,自己建的model有的参数model_dict.update(pretrained_dict)  # 将预训练的值,更新到自己模型的dict中model.load_state_dict(model_dict)  # model加载dict中的数据,更新网络的初始值

3. 通过查看加载参数,看是否加载成功

for value1 ,value2 in zip(checkpoint.items(), model.state_dict().items()):print(value1,value2)

如下所示,model的参数和预训练的参数是一样的
pytorch加载模型遇到的问题怎么解决

4. 案例

(这里处理的只是针对本人的model加载的情况,要想正确加载,还需遵守上面3步)

    def load_param(self, model_path):#这里的self就是modelmodel_dict = self.state_dict()pretrained_dict = torch.load(model_path)#这里model_path的后缀是.pth可直接读取# pretrained_dict = {k.replace('module.', ''): v for k, v in#                    pretrained_dict.items()}  # 因为pretrained_dict得到module.conv1.weight,但是自己建的model无module,只是conv1.weight,所以改写下pretrained_dict = {
   
   
   k: v for k, v in pretrained_dict.items() if k in model_dict}  # 只保留预训练模型中,自己建的model有的参数model_dict.update(pretrained_dict)  # 将预训练的值,更新到自己模型的dict中self.load_state_dict(model_dict)  # model加载dict中的数据,更新网络的初始值

感谢各位的阅读,以上就是“pytorch加载模型遇到的问题怎么解决”的内容了,经过本文的学习后,相信大家对pytorch加载模型遇到的问题怎么解决这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是亿速云,小编将为大家推送更多相关知识点的文章,欢迎关注!

向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI