
Pytorch模型参数的保存和加载

目录
- 一、前言
- 二、参数保存
- 三、参数的加载
- 四、保存和加载整个模型
- 五、总结
一、前言
在模型训练完成后,我们需要保存模型参数值用于后续的测试过程。由于保存整个模型将耗费大量的存储,故推荐的做法是只保存参数,使用时只需在建好模型的基础上加载。
通常来说,保存的对象包括网络参数值、优化器参数值、epoch值等。本文将简单介绍保存和加载模型参数的方法,同时也给出保存整个模型的方法供大家参考。
二、参数保存
在这里我们使用 torch.save() 函数保存模型参数:
import torch path = './model.pth' torch.save(model.state_dict(), path)
model——指定义的模型实例变量,如model=net( )
state_dict()——state_dict( )是一个可以轻松地保存、更新、修改和恢复的python字典对象, 对于model来说,表示模型的每一层的权重及偏置等参数信息;对于 optimizer 来说,其包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)
path——path是保存参数的路径,一般设置为 path='./model.pth' , path='./model.pkl'等形式。
此外,如果想保存某一次训练采用的optimizer、epochs等信息,可将这些信息组合起来构成一个字典保存起来:
import torch path = './model.pth' state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch} torch.save(state, path)
三、参数的加载
使用 load_state_dict()函数加载参数到模型中, 当仅保存了模型参数,而没有optimizer、epochs等信息时:
model.load_state_dict(torch.load(path))
model——事先定义好的跟原模型一致的模型
path——之前保存的模型参数文件
如若保存了optimizer、epochs等信息,我们这样载入信息:
# 使用torch.load()函数将文件中字典信息载入 state_dict 变量中 state_dict = torch.load(path) # 分布加载参数到模型和优化器 model.load_state_dict(state_dict['model']) optimizer.load_state_dict(state_dict['optimizer']) epoch = state_dict(['epoch'])
我们还可以在每n个epoch后保存一次参数,以观察不同迭代次数模型的表现。此时我们可设置不同的path,如 path='./model' + str(epoch) +'.pth',这样,不同epoch的参数就能保存在不同的文件中。
四、保存和加载整个模型
使用上文提到的方法即可:
torch.save(model, path) model = torch.load(path)
五、总结
pytorch中state_dict()和load_state_dict()函数配合使用可以实现状态的获取与重载,load()和save()函数配合使用可以实现参数的存储与读取。掌握对应的函数使用方法就可以游刃有余地进行运用。
到此这篇关于Pytorch模型参数的保存和加载的文章就介绍到这了,更多相关Pytorch模型参数保存内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!
相关推荐
-
Pytorch中实现只导入部分模型参数的方式
我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected).我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed).如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误.那么在这种情况下,该如何导入模型呢? 好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数
-
基于pytorch的保存和加载模型参数的方法
当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torch.save(net.state_dict(),path): 功能:保存训练完的网络的各层参数(即weights和bias) 其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth) net2.load_state_dict(torch.loa
-
PyTorch计算损失函数对模型参数的Hessian矩阵示例
目录 前言 模型定义 求解Hessian矩阵 前言 在实现Per-FedAvg的代码时,遇到如下问题: 可以发现,我们需要求损失函数对模型参数的Hessian矩阵. 模型定义 我们定义一个比较简单的模型: class ANN(nn.Module): def __init__(self): super(ANN, self).__init__() self.sigmoid = nn.Sigmoid() self.fc1 = nn.Linear(3, 4) self.fc2 = nn.Linear(4
-
pytorch 求网络模型参数实例
用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量.下面分别介绍来两种方法求模型参数 一 .求得每一层的模型参数,然后自然的可以计算出总的参数. 1.先初始化一个网络模型model 比如我这里是 model=cliqueNet(里面是些初始化的参数) 2.调用model的Parameters类获取参数列表 一个典型的操作就是将参数列表传入优化器里.如下 optimizer = optim.Adam(model.parameters(), lr=opt.lr) 言归正传,继续回到参
-
Pytorch 统计模型参数量的操作 param.numel()
param.numel() 返回param中元素的数量 统计模型参数量 num_params = sum(param.numel() for param in net.parameters()) print(num_params) 补充:Pytorch 查看模型参数 Pytorch 查看模型参数 查看利用Pytorch搭建模型的参数,直接看程序 import torch # 引入torch.nn并指定别名 import torch.nn as nn import torch.nn.functio
-
PyTorch和Keras计算模型参数的例子
Pytorch中,变量参数,用numel得到参数数目,累加 def get_parameter_number(net): total_num = sum(p.numel() for p in net.parameters()) trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) return {'Total': total_num, 'Trainable': trainable_num} Kera
-
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
1. 利用resnet18做迁移学习 import torch from torchvision import models if __name__ == "__main__": # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = 'cpu' print("-----device:{}".format(device))
-
PyTorch深度学习模型的保存和加载流程详解
一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt..pth或.pkl). torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 . torch.nn.Module.state_dict()函数
-
pytorch模型的保存和加载、checkpoint操作
其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习把~ pytorch的模型和参数是分开的,可以分别保存或加载模型和参数.所以pytorch的保存和加载对应存在两种方式: 1. 直接保存加载模型 (1)保存和加载整个模型 # 保存模型 torch.save(model, 'model.pth\pkl\pt') #一般形式torch.save(net, PATH) # 加载模型 model = torc
-
解决tensorflow模型参数保存和加载的问题
终于找到bug原因!记一下:还是不熟悉平台的原因造成的! Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误? 而 直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错? model.py,里面含有 ModelV 和 ModelP,另外还有 modelP.py 和 modelV.py 分别只含有 ModelP 和 ModeV 这两个对象,先使用 modelP.py 和 m
-
在Keras中实现保存和加载权重及模型结构
1. 保存和加载模型结构 (1)保存为JSON字串 json_string = model.to_json() (2)从JSON字串重构模型 from keras.models import model_from_json model = model_from_json(json_string) (3)保存为YAML字串 yaml_string = model.to_yaml() (4)从YAML字串重构模型 model = model_from_yaml(yaml_string) 2. 保存和
-
keras训练浅层卷积网络并保存和加载模型实例
这里我们使用keras定义简单的神经网络全连接层训练MNIST数据集和cifar10数据集: keras_mnist.py from sklearn.preprocessing import LabelBinarizer from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report from keras.models import Sequential
-
pytorch模型保存与加载中的一些问题实战记录
目录 前言 一.torch中模型保存和加载的方式 1.模型参数和模型结构保存和加载 2.只保存模型的参数和加载——这种方式比较安全,但是比较稍微麻烦一点点 二.torch中模型保存和加载出现的问题 1.单卡模型下保存模型结构和参数后加载出现的问题 2.多卡机器单卡训练模型保存后在单卡机器上加载会报错 3.多卡训练模型保存模型结构和参数后加载出现的问题 三.正确的保存模型和加载的方法 总结 前言 最近使用pytorch训练模型,保存模型后再次加载使用出现了一些问题.记录一下解决方案! 一.torc
-
PyTorch模型的保存与加载方法实例
目录 模型的保存与加载 保存和加载模型参数 保存和加载模型参数与结构 总结 模型的保存与加载 首先,需要导入两个包 import torch import torchvision.models as models 保存和加载模型参数 PyTorch模型将学习到的参数存储在一个内部状态字典中,叫做state_dict.这可以通过torch.save方法来实现.我们导入预训练好的VGG16模型,并将其保存.我们将state_dict字典保存在model_weights.pth文件中. model =
-
Python 保存加载mat格式文件的示例代码
mat为matlab常用存储数据的文件格式,python的scipy.io模块中包含保存和加载mat格式文件的API,使用极其简单,不再赘述:另附简易示例如下: # -*- coding: utf-8 -*- import numpy as np import scipy.io as scio # data data = np.array([1,2,3]) data2 = np.array([4,5,6]) # save mat (data format: dict) scio.savemat(
-
PyTorch模型保存与加载实例详解
目录 一个简单的例子 保存/加载 state_dict(推荐) 保存/加载整个模型 保存加载用于推理的常规Checkpoint/或继续训练 保存多个模型到一个文件 使用其他模型来预热当前模型 跨设备保存与加载模型 总结 torch.save:保存序列化的对象到磁盘,使用了Python的pickle进行序列化,模型.张量.所有对象的字典. torch.load:使用了pickle的unpacking将pickled的对象反序列化到内存中. torch.nn.Module.load_state_di
随机推荐
- 浅谈Xcode9 和iOS11适配和特性
- prototype与jquery下Ajax实现的差别
- 用VBS控制鼠标的实现代码(获取鼠标坐标、鼠标移动、鼠标单击、鼠标双击、鼠标右击)
- IOS定制属于自己的个性头像
- oracle select执行顺序的详解
- JS模拟实现ECMAScript5新增的数组方法
- asp.net利用后台实现直接生成html分页的方法
- php中数据库连接方式pdo和mysqli对比分析
- Go语言中的range用法实例分析
- MySql 5.6.14 winx64配置方法(免安装版)
- JavaScript实现横向滑出的多级菜单效果
- NodeJs中的非阻塞方法介绍
- 一个比较简单的PHP 分页分组类
- 整理一下常见的IE错误
- Mysql Binlog快速遍历搜索记录及binlog数据查看的方法
- 最新最热最实用的15个jQuery插件汇总
- 用Jquery实现滚动新闻
- jQuery实现鼠标单击网页文字后在文本框显示的方法
- 为JQuery EasyUI 表单组件增加焦点切换功能的方法
- Flash图片上传组件 swfupload使用指南
其他
- vue el-select 宽度
- laravel join拓展
- java join改变优先级吗
- String task 动态生成
- IE上 vue 白屏没有进入 路由
- c# networkstream 粘包
- python read 分块读取大文件
- python自动平稳化
- nginx1.21.6集群部署
- spring 包扫描配置
- mybatis 多个or联合使用,有空
- vue3 reactive重置对象
- asyncio 异步写日志
- element-ui和ajax实现一个表格
- pandas 除法精度
- 二维码 解析 js库
- mui加载外部网页 缓慢
- mybatisplus一对一文字替换
- Android中断处理
- vue treeselect 清除