Pytorch如何保存训练好的模型

2023-11-16

0.为什么要保存和加载模型

用数据对模型进行训练后得到了比较理想的模型,但在实际应用的时候不可能每次都先进行训练然后再使用,所以就得先将之前训练好的模型保存下来,然后在需要用到的时候加载一下直接使用。模型的本质是一堆用某种结构存储起来的参数,所以在保存的时候有两种方式,一种方式是直接将整个模型保存下来,之后直接加载整个模型,但这样会比较耗内存;另一种是只保存模型的参数,之后用到的时候再创建一个同样结构的新模型,然后把所保存的参数导入新模型。

1.两种情况的实现方法

(1)只保存模型参数字典(推荐)

#保存
torch.save(the_model.state_dict(), PATH)
#读取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

(2)保存整个模型

#保存
torch.save(the_model, PATH)
#读取
the_model = torch.load(PATH)

3.只保存模型参数的情况(例子)

pytorch会把模型的参数放在一个字典里面,而我们所要做的就是将这个字典保存,然后再调用。
比如说设计一个单层LSTM的网络,然后进行训练,训练完之后将模型的参数字典进行保存,保存为同文件夹下面的rnn.pt文件:

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
         # 2 for bidirection
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  
        # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        out = self.fc(out)
        return out


rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)

# optimize all cnn parameters
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  
# the target label is not one-hotted
loss_func = nn.MSELoss()  

for epoch in range(1000):
    output = rnn(train_tensor)  # cnn output`
    loss = loss_func(output, train_labels_tensor)  # cross entropy loss
    optimizer.zero_grad()  # clear gradients for this training step
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
    output_sum = output


# 保存模型
torch.save(rnn.state_dict(), 'rnn.pt')

保存完之后利用这个训练完的模型对数据进行处理:

# 测试所保存的模型
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
predict = new_m(test_tensor)

这里做一下说明,在保存模型的时候rnn.state_dict()表示rnn这个模型的参数字典,在测试所保存的模型时要先将这个参数字典加载一下m_state_dict = torch.load('rnn.pt')

然后再实例化一个LSTM对像,这里要保证传入的参数跟实例化rnn是传入的对象时一样的,即结构相同new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)

下面是给这个新的模型传入之前加载的参数new_m.load_state_dict(m_state_dict)

最后就可以利用这个模型处理数据了predict = new_m(test_tensor)

4.保存整个模型的情况(例子)

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)  # 2 for bidirection
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        # print("output_in=", out.shape)
        # print("fc_in_shape=", out[:, -1, :].shape)
        # Decode the hidden state of the last time step
        # out = torch.cat((out[:, 0, :], out[-1, :, :]), axis=0)
        # out = self.fc(out[:, -1, :])  # 取最后一列为out
        out = self.fc(out)
        return out


rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)


optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all cnn parameters
loss_func = nn.MSELoss()  # the target label is not one-hotted

for epoch in range(1000):
    output = rnn(train_tensor)  # cnn output`
    loss = loss_func(output, train_labels_tensor)  # cross entropy loss
    optimizer.zero_grad()  # clear gradients for this training step
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
    output_sum = output


# 保存模型

torch.save(rnn, 'rnn1.pt')

保存完之后利用这个训练完的模型对数据进行处理:

new_m = torch.load('rnn1.pt')
predict = new_m(test_tensor)

参考pytorch的官方文档

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch如何保存训练好的模型 的相关文章

随机推荐

  • uView的组件u-picker 选择器

    网址 https www uviewui com components picker html 需要的是数组中的数组 处理核心是将接口获取回来的数组 赋给一个空数组 然后把这空数组再push到一个空数组里面 this arr1 res da
  • docker的配置,基础用法

    什么是docker docker中的容器 lxc gt libcontainer gt runC OCI OCF OCI Open Container initiative 开放容器倡议 由Linux基金会主导于2015年6月创立 旨在围绕
  • 【Swift】LeedCode整数反转

    Swift LeedCode 拿硬币 由于各大平台的算法题的解法很少有Swift的版本 小编这边将会出个专辑为手撕LeetCode算法题 新手撕算法 请包涵 给你一个 32 位的有符号整数 x 返回将 x 中的数字部分反转后的结果 如果反转
  • 小程序分享及返回上级页面

    分享监听 用户点击右上角分享 onShareAppMessage function res console log res if res from menu return title 邀请赢好礼 path pages member memb
  • 【hadoop报错】(ssh)Connection timed out

    背景 在虚拟机上启动hadoop 报错 Starting namenodes on node1 huike cn Last login Thu Oct 14 22 36 08 EDT 2021 on pts 0 node1 huike cn
  • 为什么单线程的Redis那么快?

    1 Redis单线程的本质 其实 Redis并不是单线程 我们之所以会一直称Redis是单线程 这是因为Redis在处理客户端的读写请求时 只有一个主线程 而在处理以下这些操作时 Redis会fork出其他的子线程来处理 主从数据同步 切片
  • 多元统计分析与R语言练习

    多元考试练习 文章目录 多元考试练习 一 多元线性回归模型 1 建立回归模型 2 逐步筛选 3 最优标准方程 影响最大 4 全局择优法 使用4 2 1版本的R 5 分析 6 由标准化偏回归系数可见 方差分析结果 二 判别分析 1 线性判别
  • python学习心得总结

    21年7月8上午是我第一次接触python这个语言 对于python这个语言之前了解的也并不是很多 也可以说几乎为零 因为我们之前的学习也不考python 所以也没想过去主动学习它 然而当我听老师讲解的时候 首先我发现python这个语言相
  • 使用VS2005下自带的MSSQL 2005 EXPRESS

    VS2005安装后自带一个试用版的SQL2005 EXPRESS版 方便了开发时使用数据库 不用再安装一个sql 2005 怪占用资源的 如何使用 安装后 在开始菜单里出现个sql的菜单组 但是找不到sql server的控制台 习惯用sq
  • 《机器学习实战》——决策树

    本章介绍的决策树算法为ID3算法 Iterative Dichotomiser 3 迭代二叉树3代 主要流程为 根据信息增益找到划分数据的最佳特征 判断划分后每个数据子集是否为同一分类 若是 返回分类结果 若不是 再次划分数据子集 递归 同
  • iOS. Xcode11 dylib封装成framework 图文教程

    Frameworks 制作 Xcode 版本 1 framework是什么 framework是一个层级的目录结构 将一系列可共享的资源 比如动态共享库 nib文件 图形文件 本地化相关文件 头文件 以及相关引用文档 包装成一个包 pack
  • 输入PM2.5的值,判断空气质量

    一个简单的if语句 a int input 请输入PM2 5的值 if 0 lt a lt 35 print 优 elif 35 lt a lt 75 print 良 elif 75 lt a lt 115 print 轻度污染 elif
  • Linux下的文件名空格处理

    转载原文 https blog csdn net michaelzhou224 article details 12708333 解决空格问题的几种方案 1 使用 来替代一个含有空格的文件以及目录 jorncess red black 可以
  • Android开发-Android项目结构

    文章目录 前言 一 Gradle 1 1什么是Gradle 1 2Gradle是一个构建工具 那么为什么要用构建工具 二 项目结构 三 app目录结构 四 res目录结构 总结 前言 Android工程的项目结构比较复杂 在进行Androi
  • AWD简单介绍和搭建AWD平台

    AWD简单介绍和搭建AWD平台 何为AWD 比赛中每个队伍维护多台服务器 服务器中存在多个漏洞 利用漏洞攻击其他队伍可以进行得分 修复漏洞可以避免被其他队伍攻击失分 1 一般分配Web服务器 服务器 多数为Linux 某处存在flag 一般
  • KNN数据分类算法的matlab仿真

    目录 1 算法概述 2 仿真效果 3 MATLAB仿真源码 1 算法概述 KNN的本质是通过距离判断待测样本和已知样本是否相似 待测样本找到与已知样本中与其距离最近的K个样本 对这k个样本 它们大多数属于哪一类别 就把待测样本归为哪一类别
  • [工程编写]cmakelist多版本python环境编写

    问题 最近在写一个工程的时候需要用到python3 但是由于引入了ROS相关的环境 导致希望使用python3的那部分代码一直默认使用ROS中的python2 这样环境就不对了 解决的方法 很顺理成章的想法是为需要python3的那部分代码
  • 注解&反射学习笔记

    1 注解的作用域及使用方式 表示我们的注解可以使用在那些地方 Target value ElementType METHOD ElementType TYPE 表示注解在什么地方有效 RESOUT 源码 lt CLASS 类 lt RUNT
  • IntelliJ Idea 常用快捷键 超实用!

    IntelliJ Idea 常用快捷键 列表 实战终极总结 1 自动代码 常用的有fori sout psvm Tab即可生成循环 System out main方法等boilerplate样板代码 例如要输入for User user u
  • Pytorch如何保存训练好的模型

    0 为什么要保存和加载模型 用数据对模型进行训练后得到了比较理想的模型 但在实际应用的时候不可能每次都先进行训练然后再使用 所以就得先将之前训练好的模型保存下来 然后在需要用到的时候加载一下直接使用 模型的本质是一堆用某种结构存储起来的参数