PyTorch模型保存与加载

2023-05-16

  1. torch.save:保存序列化的对象到磁盘,使用了Pythonpickle进行序列化,模型、张量、所有对象的字典。
  2. torch.load:使用了pickle的unpacking将pickled的对象反序列化到内存中。
  3. torch.nn.Module.load_state_dict:使用反序列化的state_dict加载模型的参数字典。

state_dict 是一个Python字典,将每一层映射成它的参数张量。注意只有带有可学习参数的层(卷积层、全连接层等),以及注册的缓存(batchnorm的运行平均值)在state_dict 中才有记录。state_dict同样包含优化器对象,存储了优化器的状态,所使用到的超参数。

一个简单的例子

# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型
model = TheModelClass()

# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# 打印优化器的 state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

输出:

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

保存/加载 state_dict(推荐)

保存:

torch.save(model.state_dict(), PATH)

加载:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

要注意这个细节,如果使用nn.DataParallel在一台电脑上使用了多个GPU,那么加载模型的时候也必须先进行nn.DataParallel

保存模型的推理过程的时候,只需要保存模型训练好的参数,使用torch.save()保存state_dict,能够方便模型的加载。因此推荐使用这种方式进行模型保存。

记住一定要使用model.eval()来固定dropout和归一化层,否则每次推理会生成不同的结果。

注意,load_state_dict()需要传入字典对象,因此需要先反序列化state_dict再传入load_state_dict()

保存/加载整个模型

保存:

torch.save(model, PATH)

加载:

# 模型类必须在别的地方定义
model = torch.load(PATH)
model.eval()

这种保存/加载模型的过程使用了最直观的语法,所用代码量少。这使用Python的pickle保存所有模块。这种方法的缺点是,保存模型的时候,序列化的数据被绑定到了特定的类和确切的目录。这是因为pickle不保存模型类本身,而是保存这个类的路径,并且在加载的时候会使用。因此,当在其他项目里使用或者重构的时候,加载模型的时候会出错。

一般来说,PyTorch的模型以.pt或者.pth文件格式保存。

一定要记住在评估模式的时候调用model.eval()来固定dropout和批次归一化。否则会产生不一致的推理结果。

保存加载用于推理的常规Checkpoint/或继续训练

保存:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - 或者 -
model.train()

在保存用于推理或者继续训练的常规检查点的时候,除了模型的state_dict之外,还必须保存其他参数。保存优化器的state_dict也非常重要,因为它包含了模型在训练时候优化器的缓存和参数。除此之外,还可以保存停止训练时epoch数,最新的模型损失,额外的torch.nn.Embedding层等。

要保存多个组件,则将它们放到一个字典中,然后使用torch.save()序列化这个字典。一般来说,使用.tar文件格式来保存这些检查点。

加载各个组件,首先初始化模型和优化器,然后使用torch.load()加载保存的字典,然后可以直接查询字典中的值来获取保存的组件。

同样,评估模型的时候一定不要忘了调用model.eval()

保存多个模型到一个文件

保存:

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

加载:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - 或者 -
modelA.train()
modelB.train()

保存的模型包含多个torch.nn.Modules时,比如GAN,一个序列-序列模型,或者组合模型,使用与保存常规检查点的方式来保存模型。也就是说,保存每个模型的state_dict和对应的优化器到一个字典中。我们可以保存任何能帮助我们继续训练的东西到这个字典中。

使用其他模型来预热当前模型

保存:

torch.save(modelA.state_dict(), PATH)

加载:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

在迁移学习或者训练新的复杂模型时,加载部分模型是很常见的。利用经过训练的参数,即使只有少数参数可用,也将有助于预热训练过程,并且使模型更快收敛。

在加载部分模型参数进行预训练的时候,很可能会碰到键不匹配的情况(模型权重都是按键值对的形式保存并加载回来的)。因此,无论是缺少键还是多出键的情况,都可以通过在load_state_dict()函数中设定strict参数为False来忽略不匹配的键。

如果想将某一层的参数加载到其他层,但是有些键不匹配,那么修改state_dict中参数的key可以解决这个问题。

跨设备保存与加载模型

GPU上保存,CPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

当在CPU上加载一个GPU上训练的模型时,在torch.load()中指定map_location=torch.device('cpu'),此时,map_location动态地将tensors的底层存储重新映射到CPU设备上。

上述代码只有在模型是在一块GPU上训练时才有效,如果模型在多个GPU上训练,那么在CPU上加载时,会得到类似如下错误:

KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因是在使用多GPU训练并保存模型时,模型的参数名都带上了module前缀,因此可以在加载模型时,把key中的这个前缀去掉:

# 原始通过DataParallel保存的文件
state_dict = torch.load('myfile.pth.tar')
# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)

GPU上保存,GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# 往模型中输入数据的时候不要忘记在任意tensor上调用input = input.to(device)

在把GPU上训练的模型加载到GPU上时,只需要使用model.to(torch.devie('cuda'))将初始化的模型转换为CUDA优化模型。同时确保在模型所有的输入上使用.to(torch.device('cuda'))。注意,调用my_tensor.to(device)会返回一份在GPU上的my_tensor的拷贝。不会覆盖原本的my_tensor,因此要记得手动将tensor重写:my_tensor = my_tensor.to(torch.device('cuda'))

CPU上保存,GPU上加载

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # 选择希望使用的GPU
model.to(device)

保存torch.nn.DataParallel模型

保存:

torch.save(model.module.state_dict(), PATH)

扫码关注微信公众号:机器工匠,阅读更多文章。
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch模型保存与加载 的相关文章

  • 关于 Pytorch 中的奇数图像尺寸

    因此 我目前正在构建一个 2 通道 也称为双通道 卷积神经网络 用于测量 2 个 二进制 图像之间的相似度 我遇到的问题如下 我的输入图像为 40 x 50 经过 1 个卷积层和 1 个池化层 例如 后 输出大小为 18 x 23 那么如何
  • 使用 torch.stack()

    t1 torch tensor 1 2 3 t2 torch tensor 4 5 6 t3 torch tensor 7 8 9 torch stack t1 t2 t3 dim 1 在实现 torch stack 时 我无法理解如何对不
  • Pytorch:获取最终层的正确尺寸

    Pytorch 新手来了 我正在尝试微调 VGG16 模型来预测 3 个不同的类别 我的部分工作涉及将 FC 层转换为 CONV 层 但是 我的预测值不会落在 0 到 2 3 个类别 之间 有人可以向我指出有关如何计算最后一层的正确尺寸的好
  • 如何计算图像数据集中 RGB 值的 3x3 协方差矩阵?

    我需要计算图像数据集中 RGB 值的协方差矩阵 然后将 Cholesky 分解应用于最终结果 RGB 值的协方差矩阵是 3x3 矩阵 M 其中 M i i 是通道 i 的方差 M i j 是通道 i 和 j 之间的协方差 最终结果应该是这样
  • 如何在 PyTorch 数据加载器中将 RGB 图像转换为灰度图像?

    我已经从 MNIST 数据集中下载了一些示例图像 jpg格式 现在我正在加载这些图像来测试我的预训练模型 transforms to apply to the data trans transforms Compose transforms
  • 如何在 PyTorch 中的特定新维度中重复张量

    如果我有一个张量A有形状 M N 我想重复张量 K 次 以便结果B有形状 M K N 和每片B k 应该具有相同的数据A 这是没有 for 循环的最佳实践 K可能在其他维度 torch repeat interleave and tenso
  • 在 Pytorch 中获取负片(倒置)图像

    我想直接从数据加载器获取图像的负片并将其作为张量提供 有我可以使用的库吗 我试过火炬transforms并没有找到任何 不要费力 只需使用255 image它会给你一个负面的形象 试试吧
  • max_length、填充和截断参数在 HuggingFace 的 BertTokenizerFast.from_pretrained('bert-base-uncased') 中如何工作?

    我正在处理文本分类问题 我想使用 BERT 模型作为基础 然后使用密集层 我想知道这 3 个参数是如何工作的 例如 如果我有 3 个句子 My name is slim shade and I am an aspiring AI Engin
  • 如何检查 PyTorch 是否正在使用 GPU?

    如何检查 PyTorch 是否正在使用 GPU 这nvidia smi命令可以检测 GPU 活动 但我想直接从 Python 脚本内部检查它 这些功能应该有助于 gt gt gt import torch gt gt gt torch cu
  • PyTorch - 参数不变

    为了了解 pytorch 的工作原理 我尝试对多元正态分布中的一些参数进行最大似然估计 然而 它似乎不适用于任何协方差相关的参数 所以我的问题是 为什么这段代码不起作用 import torch def make covariance ma
  • PoseWarping:如何矢量化此 for 循环(z 缓冲区)

    我正在尝试使用地面真实深度图 姿势信息和相机矩阵将帧从视图 1 扭曲到视图 2 我已经能够删除大部分 for 循环并将其矢量化 除了一个 for 循环 扭曲时 由于遮挡 视图 1 中的多个像素可能会映射到视图 2 中的单个位置 在这种情况下
  • pytorch通过易失性变量反向传播错误

    我试图通过多次向后传递迭代来运行它并在每个步骤更新输入 从而最小化相对于某个目标的一些输入 第一遍运行成功 但在第二遍时出现以下错误 RuntimeError element 0 of variables tuple is volatile
  • Pytorch:了解 nn.Module 类内部如何工作

    一般来说 一个nn Module可以由子类继承 如下所示 def init weights m if type m nn Linear torch nn init xavier uniform m weight class LinearRe
  • 为什么测试时一定要用DataParallel?

    在GPU上训练 num gpus设置为1 device ids list range num gpus model NestedUNet opt num channel 2 to device model nn DataParallel m
  • PyTorch 中复数矩阵的行列式

    有没有办法在 PyTorch 中计算复矩阵的行列式 torch det未针对 ComplexFloat 实现 不幸的是 目前尚未实施 一种方法是实现您自己的版本或简单地使用np linalg det 这是一个简短的函数 它计算我使用 LU
  • 使 CUDA 内存不足

    我正在尝试训练网络 但我明白了 我将批量大小设置为 300 并收到此错误 但即使我将其减少到 100 我仍然收到此错误 更令人沮丧的是 在 1200 个图像上运行 10 epoch 大约需要 40 分钟 有什么建议吗 错了 我怎样才能加快这
  • PyTorch LSTM:运行时错误:无效参数 0:张量的大小必须匹配,维度 0 除外。维度 1 为 1219 和 440

    我有一个基本的 PyTorch LSTM import torch nn as nn import torch nn functional as F class BaselineLSTM nn Module def init self su
  • 如何计算 CNN 第一个线性层的维度

    目前 我正在使用 CNN 其中附加了一个完全连接的层 并且我正在使用尺寸为 32x32 的 3 通道图像 我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积 最大池层的输入 我希望能够计算第一个线性层的尺寸 仅给出
  • 如何有效地对一个数组中某个值在另一个数组中的位置出现的次数求和

    我正在寻找一种有效的 for 循环 避免解决方案来解决我遇到的数组相关问题 我想使用一个巨大的一维数组 A gt size 250 000 用于一维索引的 0 到 40 之间的值 以及用于第二维索引的具有 0 到 9995 之间的值的相同大
  • 如何使用Python计算多类分割任务的dice系数?

    我想知道如何计算多类分割的骰子系数 这是计算二元分割任务的骰子系数的脚本 如何循环每个类并计算每个类的骰子 先感谢您 import numpy def dice coeff im1 im2 empty score 1 0 im1 numpy

随机推荐

  • Windows 11下载

    Windows 11是微软于2021年推出的Windows NT系列操作系统 xff0c 为Windows 10的后继者 正式版本于2021年10月5日发行 xff0c 并开放给符合条件的Windows 10设备通过Windows Upda
  • docker容器安装图形桌面

    文章目录 视频教程版本信息创建一个CONTAINERubuntu官方国内源docker镜像unminimize中文环境设置中文环境 安装安装TigerVNC Server安装 xfce4精简版本 配置设置vnc密码 vnc xstartup
  • ubuntu官方国内源

    背景 之前我一直在使用中科大的源 xff0c 还是挺快的 一直也没有感觉有什么问题 直到最近在折腾vnc xff0c 发现中科大的源有一些包会404 xff0c 安装不了 而我在vmware中的正好是默认的cn archive ubuntu
  • mame新版ROM下载网站推荐

    网站地址 https www retroroms info index php 中文插件安装 浏览器插件 https www tampermonkey net UP主自己写的脚本 已经失效 https gitee com lxyoucan
  • RuoYi若依打包发布与部署

    上一节我们已经讲过了如果搭建开发环境 xff0c 那么如果代码写完了 xff0c 如何打包发布 部署到生产环境呢 xff1f RuoYi开发实战 搭建开发环境 https blog csdn net lxyoucan article det
  • vscode设置Prettier为默认格式化插件

    1 目的 xff1a ctrl 43 s保存 xff0c 自动格式化文档 2 所需插件Prettier 3 操作步骤 先打开vscode软件 xff0c 左下角点击设置 gt 打开设置 gt 在右上方有一个搜索框 先设定自动保存文件 xff
  • ASUS X415安装系统找不到硬盘解决办法

    同事让我帮忙安装系统 xff0c 笔记本电脑型号是ASUS X415 原本以为是手到擒来的事情 xff0c 结果我在上面还是消耗了不少时间 现象 老毛桃PE 无法识别到硬盘 微PE可以识别到硬盘 xff0c 但是系统安装以后 xff0c 无
  • archlinux中navicat无法使用fcitx5输入法

    现象 archlinux中navicat无法使用fcitx5输入法 而我在ubuntu中使用navicat调用fcitx输入法是可以正常使用的 在网上搜索了很久 xff0c 这方面的文章比较少 而我的其他程序输入法又是正常的 解决办法 参考
  • JetBrains Gateway IDEA远程开发

    为什么进行远程开发 xff1f 无论身处何处数秒内连接至远程环境 充分利用远程计算机的强大功能 在任何笔记本电脑上都可以轻松工作 xff0c 无论其性能如何 借助远程计算机的计算资源 xff0c 充分利用最大规模的数据集和代码库 在远程服务
  • ubuntu 22.04安装nvm

    执行安装脚本 span class token function sudo span span class token function apt span span class token function install span spa
  • 手推DNN,CNN池化层,卷积层反向传播

    反向传播算法是神经网络中用来学习的算法 xff0c 从网络的输出一直往输出方向计算梯度来更新网络参数 xff0c 达到学习的目的 xff0c 而因为其传播方向与网络的推理方向相反 xff0c 因此成为反向传播 神经网络有很多种 xff0c
  • 软件架构概念和面向服务的架构

    摘要 软件架构作为软件开发过程的一个重要组成部分 xff0c 有着各种各样的方法和路线图 xff0c 它们都有一些共同的原则 基于架构的方法作为控制系统构建和演化复杂性的一种手段得到了推广 引言 在计算机历史中 xff0c 软件变得越来越复
  • 初识强化学习,什么是强化学习?

    相信很多人都听过 机器学习 和 深度学习 但是听过 强化学习 的人可能没有那么多 那么 什么是强化学习呢 强化学习是机器学习的一个子领域 它可以随着时间的推移自动学习到最优的策略 在我们不断变化的纷繁复杂的世界里 从更广的角度来看 即使是单
  • 强化学习形式与关系

    在强化学习中有这么几个术语 智能体 Agent 环境 Environment 动作 Action 奖励 Reward 状态 State 有些地方称作观察 Observation 奖励 Reward 在强化学习中 奖励是一个标量 它是从环境中
  • 多层网络和反向传播笔记

    在我之前的博客中讲到了感知器 xff08 感知器 xff09 xff0c 它是用于线性可分模式分类的最简单的神经网络模型 xff0c 单个感知器只能表示线性的决策面 xff0c 而反向传播算法所学习的多层网络能够表示种类繁多的非线性曲面 对
  • 在Kaggle手写数字数据集上使用Spark MLlib的朴素贝叶斯模型进行手写数字识别

    昨天我在Kaggle上下载了一份用于手写数字识别的数据集 xff0c 想通过最近学习到的一些方法来训练一个模型进行手写数字识别 这些数据集是从28 28像素大小的手写数字灰度图像中得来 xff0c 其中训练数据第一个元素是具体的手写数字 x
  • Ros使用自定义数据通讯无法收到消息的分析和解决

    nbsp 在实际的开发中 和别的模块定义了自定义的 数据类型 比如 userMsg msg文件 Header header int32 nState string strImageName string strYamlName 报错和原因
  • 在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别

    昨天我使用Spark MLlib的朴素贝叶斯进行手写数字识别 xff0c 准确率在0 83左右 xff0c 今天使用了RandomForest来训练模型 xff0c 并进行了参数调优 首先来说说RandomForest 训练分类器时使用到的
  • 遇见AI,从Java到数据挖掘。

    在上小学的时候就听说过AI xff0c 人工智能 xff0c 那个时候我对人工智能的感受都来自于各类影视作品 xff0c 类人的外表 xff0c 能听说读写 xff0c 有情感 xff0c 会思考 所以那个时候的我将人工智能想象成和人类相似
  • PyTorch模型保存与加载

    torch save xff1a 保存序列化的对象到磁盘 xff0c 使用了Python的pickle进行序列化 xff0c 模型 张量 所有对象的字典 torch load xff1a 使用了pickle的unpacking将pickle