pytorch CPU与GPU模型参数相互加载

2023-11-12


1. 模型保存以及加载方法

# 直接保存模型 (参数 + 网络结构)
torch.save(model, '/path/to/save')
model = torch.load('/path/to/load')
# 只保存参数 (推荐)
torch.save(model.state_dict(), '/path/to/save')
model = NET()
mode.load_state_dict(torch.load('/path/to/load'))
# 保存参数、优化器、epoch
state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch
}
torch.save(state, '/path/to/save')
checkpoint = torch.load('/path/to/load')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

2. 单 GPU 和 单 CPU 参数-模型相互加载

## CPU->CPU OR GPU->GPU 直接加载
model.load_state_dict(torch.load('/path/to/load'))

## GPU->CPU (参数->模型)
state_dict = torch.load('/path/to/load', map_location=lambda storage, loc: storage)
state_dict = torch.load('/path/to/load', map_location='cpu')

model.load_state_dict(states_dict)

## CPU->GPU (参数->模型)
state_dict = torch.load('/path/to/load', map_location=lambda storage, loc: storage.cuda)
state_dict = torch.load('/path/to/load', map_location='cuda:0')
### 指定GPU
state_dict = torch.load('/path/to/load', map_location=lambda storage, loc: storage.cuda(1))
state_dict = torch.load('/path/to/load', map_location='cuda:1')

3. 多 GPU 模型-参数

## 模型 + 参数
torch.save(model.module, '/path/to/save') # 多了个module
## 参数
torch.save(model.module.state_dict(), '/path/to/save') # 多了个module

4. 单 GPU or CPU 模型加载多 GPU 参数

## 多gpu上保存的模型在参数名前多加了一个module.前缀
device = torch.device('cpu') # cup 模型
# device = torch.device('cuda:0') # gpu 模型
model = NET().to(device)
state_dict = torch.load('/path/to/load', map_location=device)
state_dict_new = {}
for k, v in state_dict.items():
    new_k = k[7:] # 去掉键名的前七个字母,即'module.'
    state_dict_new[new_k] = v

model.load_state_dict(state_dict_new)

5. 单 GPU or CPU 加载 多GPU模型+参数

model_cpu = NET().to('cpu')
model_gpu = NET().to('cuda:0')

pretrained_model = torch.load('/path/to/load') # 模型+参数

pretrained_dict = pretrained_model.module.state_dict() # 提取参数

model_cpu.load_state_dict(pretrained_dict)
model_gpu.load_state_dicr(pretrained_dicr)

6. 多 GPU 加载 多GPU参数

model = NET().to('cuda:0')
model = torch.nn.DataParallel(model, device_ids=[0, 1])
state_dict = torch.load('/path/to/load')
model.load_state_dict(state_dict)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch CPU与GPU模型参数相互加载 的相关文章

  • 在 Django 1.6 中结合 DetailView 和 CreateView

    我有 2 个独立的模型 帖子和评论 我使用 DetailView 来显示帖子内容 并且我想使用 CreateView 在同一页面上显示评论创建表单 最干净的方法是什么 唯一想到的是使用自定义视图 它既获取对象又处理评论表单 但这看起来太脏了
  • 将 3d NumPy 数组重塑为 2d NumPy 数组时遇到问题

    我正在研究图像处理问题 我的数据以 3 维 NumPy 数组的形式呈现 其中 x y z 条目是图像 z 的 x y 像素 数值强度值 有 100000 张图像 每张图像为 25x25 因此 数据矩阵的大小为 25x25x10000 我试图
  • 检查时间戳列是否在另一个数据帧的日期范围内

    我有一个数据框 df A 有两列 amin 和 amax 这是一组时间范围 我的目标是查找 df B 中的列是否位于 df A amin 和 amax 列中范围的任何行之间 df A amin amax amin amax 0 2016 0
  • 为什么any (True for ... if cond) 比any (cond for ...) 快得多?

    检查列表是否包含奇数的两种类似方法 any x 2 for x in a any True for x in a if x 2 计时结果与a 0 10000000 每次尝试五次 次数以秒为单位 0 60 0 60 0 60 0 61 0 6
  • 如何将多项式拟合到带有误差线的数据

    我目前正在使用 numpy polyfit x y deg 将多项式拟合到实验数据 然而 我想拟合一个基于点误差使用加权的多项式 我已经发现scipy curve fit http docs scipy org doc scipy refe
  • 在 Windows 上的 python2.5 上安装 Openpyxl

    我努力了easy install install openpyxl and python setup install 两者都失败了 我也尝试过easy install openpyxl并再次失败 我包括了我得到的输出 当我尝试时easy i
  • Python int和float在64位系统中的内存消耗

    我正在 Python 3 4 的 64 位系统中尝试以下代码 以了解不同原始数据类型的内存消耗 import sys print sys getsizeof 45 prints 28 print sys getsizeof 45 2 pri
  • 当类的任何属性被修改时,类如何运行某些函数?

    是否有一些通用方法可以让类在以下情况下运行函数 any它的属性被修改了吗 我想知道是否可以运行某些子进程来监视类的更改 但也许有一种方法可以继承class并修改一些on change函数是 Python 类的一部分 有点像默认的 repr
  • numpy 中用最少内存对上三角元素求和的最快方法

    我需要进行此类求和i
  • argparse - 禁用相同参数的出现

    我正在尝试使用 argparse 禁用一个命令行中出现相同的参数 python3 argument1 something argument2 argument1 something else 这意味着这应该会引发错误 因为 argument
  • 使用 Pytest 捕获 SystemExit 消息

    我正在使用 pytest 编写测试 我遇到了一些函数抛出异常的情况SystemExit如果输入错误 终端上会显示一些错误消息 我想为以下情况编写测试SystemExit抛出并验证输出错误消息中是否有特定字符串 这是代码 def test v
  • Pandas 使用 NaN 进行数据透视或重塑数据框

    我有这个数据框 我需要根据以下数据进行旋转或重塑frame col df frame 0 0 1 1 2 2 3 0 4 1 5 2 pvol 0 nan 1 nan 2 nan 3 23 1 4 24 3 5 25 6 vvol 0 10
  • Pygame 旋转射击

    我和几个朋友一直在编写一种有趣的新射击机制 为了让它发挥作用 我们需要朝玩家面对的方向射击 Sprite 正在使用 Pygame Transform Rotate 进行旋转 我们怎样才能找到一个角度 然后朝那个方向发射子弹呢 这是我们的精灵
  • 如何更改Python中的全局变量[重复]

    这个问题在这里已经有答案了 我正在尝试更改程序中的变量 我在程序开始时声明了一个全局变量 我想在程序中的不同函数中更改该变量 我可以通过再次声明函数内的变量来做到这一点 但我想知道是否有更好的方法来做到这一点 下面是一些测试代码来解释我的意
  • Django populate() 不可重入

    当我尝试在生产环境中加载 Django 应用程序时 我不断收到此消息 我尝试了所有的 stackoverflow 答案 但没有任何解决办法 任何其他想法 我使用的是 Django 1 5 2 和 Apache Traceback most
  • 无法从 celery 信号连接到 celery 任务?

    我正在尝试连接task2 from task success signal from celery signals import task success from celery import Celery app Celery app t
  • 在IPython笔记本中自动播放声音

    我经常在 IPython 笔记本中运行长时间运行的单元 我希望笔记本在单元完成执行时自动发出蜂鸣声或播放声音 有没有办法在 iPython 笔记本中执行此操作 或者我可以在单元格末尾放置一些命令来自动播放声音 我正在使用 Chrome 如果
  • 连接 Flask Socket.IO Server 和 Flutter

    基本上 我有一个套接字 io 烧瓶代码 import cv2 import numpy as np from flask import Flask render template from flask socketio import Soc
  • 类型错误:对于仅使用浮点数的函数,返回数组必须是 ArrayType

    这个实在是难倒我了 我有一个计算单词权重的函数 我已经确认 a 和 b 局部变量都是 float 类型 def word weight term a term freq term print a type a b idf term prin
  • 使用自定义层运行 Keras 模型时出现问题

    我目前正在攻读学士学位论文FIIT STU https www fiit stuba sk en html page id 749 其主要目标是尝试复制和验证以下结果study http arxiv org abs 2006 00885 这

随机推荐

  • C/C++ 报错提示 “表达式必须包含类类型” 与 “不可访问”

    今天给大家分享两个常见的错误 定义对象 调用函数 时提示 表达式必须包含类类型 的报错 对象调用函数时提示 不可访问 的报错 一 表达式必须包含类类型 这种报错会出现在两种情况 类没有数据成员时 使用类定义对象时带括号了 定义类时以指针方式
  • MySQL重装——Database initialization failed错误处理

    卸载MySQL 笔者由于跟着网上的教程将MySQL安装到了C盘 忘记了可以走更改路径这条路 在卸载MySQL的路上一去不复返 试过网上诸多重装方案 大体均为以下步骤 控制面板卸载MySQL 删除注册表 删除ProgramData Appli
  • 导出文件:window.open()

    导出文件 window open globalBus emit loading const Download http window location host DI activity orderExcel actId this actId
  • Python-ElasticSearch客户端的封装(聚合查询、统计查询、全量数据)

    目录 ES Python客户端介绍 封装代码 测试代码 参考 ES Python客户端介绍 官方提供了两个客户端elasticsearch elasticsearch dsl pip install elasticsearch pip in
  • Flink1.13.0 + Hudi 0.11.1 + Hive2.1.1 + presto0.273.3 + yanagishima 18.0

    摘要 flink1 13 0 整合 Hudi 0 11 1 通过FlinkSQL程序 FlinkSQL命令行对Hudi的MOR及COW进行批量写 流式写 流式读取 批量读取 通过flink sql cdc flink sql kafka f
  • CSDN周赛65期简要题解

    最近几期周赛里 貌似 Python 又变成 C 站的亲儿子了 输入形式是列表还不过瘾 现在输出形式也要求是列表 而且是连一个逗号 空格 中括号都不能少的 Python 标准列表形式 虽然对 Python 来说是信手拈来 但总要考虑一下其他编
  • Node.js 使用express搭建后台服务器 ( 进阶篇 )

    上篇文章我们介绍了利用express微服务搭建简单的后台服务器以及中间件 今天我们把模块化的思想注入 利用路由分别管理 暴露API接口与前端交互等等 我们先跑起来服务 let express require express 引入expres
  • Keras深度学习资料

    https cnbeining github io deep learning with python cn 1 introduction ch1 welcome html
  • SpringBoot拦截器和动态代理有什么区别?

    在 Spring Boot 中 拦截器和动态代理都是用来实现功能增强的 所以在很多时候 有人会认为拦截器的底层是通过动态代理实现的 所以本文就来盘点一下他们两的区别 以及拦截器的底层实现 1 拦截器 拦截器 Interceptor 准确来说
  • 获取google chrome浏览器的安装位置

    今天 要获取chrome exe安装的路径 发现window XP 和 win 7 的用户路径是一样的 win7 C Users Administrator AppData Local Google Chrome Application 这
  • 2021-1-28Linux学习纪要

    linux 目录结构 bin 存放经常使用的指令 sbin 存放的是系统管理员使用的系统管理程序 home 存放普通用户的主目录 在linux中每个用户都有一个自己的目录 root 超级用户的主目录 boot 存放linux启动核心文件 包
  • 基于样本的性能分析

    基于样本的性能分析是一种性能优化技术 通过定期采集程序在运行时的信息 样本 来识别程序的热点区域或常用路径 这与基于插桩的分析相对 后者通过插入额外的代码来记录每一次函数或代码块的执行 基于样本的方法的主要优点是它对程序的性能干扰较小 但可
  • 类文件具有错误的版本号59.0应为52.0

    很笨的问题 运行 bat批处理文件的时候用管理员权限运行
  • shiro从1.6.0升级到1.7.1版本,请求路径中带有中文接口报400

    由于shiro1 6 0版本出现了安全漏洞 于是进行了版本的升级 升级到1 71 版本 但遇到了以下问题 1 访问某个接口的时候 返回状态码400 invaild request 2 访问路径为XX XXX params XXX包含中文 找
  • java---自动拆装箱

    一 什么是装箱 什么是拆箱 将一个值封装起来就是装箱 就是将一个基本类型转换为一个封装类 否则就是拆箱 而在从Java SE5开始就提供了自动装箱的特性 二 自动如何实现 来一个小栗子 public class Main public st
  • kettle循环取结果集进行处理方法一(使用js)

    需求 循环取结果集中的一行 再根据单个结果进行处理 此处实例 从test库取id字段结果集 存储 id 2 x id 到set value表中 此处使用js脚本 方法二 不使用js https blog csdn net weixin 44
  • 生成doc文件,并压缩进文件夹

    导出业务人员日志 SuppressWarnings null RequestMapping value exportEsComDailyList public void exportEsComDailyList RequestParam n
  • 提升开发效率的必备技能:Spring集成Mybatis和PageHelper详解

    目录 引言 一 Spring集成MyBatis 1 1 pom依赖 1 2 配置文件 1 3 Spring整合MyBatis 1 3 1 配置自动扫描JavaBean 1 3 2 配置数据源 1 3 3 配置session工厂 1 3 4
  • js 数组

    1 数组的创建 var arrayObj new Array 创建一个数组 var arrayObj new Array size 创建一个数组并指定长度 注意不是上限 是长度 var arrayObj new Array element0
  • pytorch CPU与GPU模型参数相互加载

    文章目录 1 模型保存以及加载方法 2 单 GPU 和 单 CPU 参数 模型相互加载 3 多 GPU 模型 参数 4 单 GPU or CPU 模型加载多 GPU 参数 5 单 GPU or CPU 加载 多GPU模型 参数 6 多 GP