Pytorch使用DDP加载模型时出现多进程在GPU0上占用过多显存的问题

2023-10-27

使用pytorch DDP(DistributedDataParallel,分布式数据并行)可以进行多卡训练,涉及到模型保存与加载问题时,一般会涉及到以下两种需求:

  1. 将多卡训练的模型保存到磁盘。
  2. 从磁盘加载模型,在多卡上继续训练。

如何无bug且高效的解决以上需求?(假设训练设备为“单机4卡”)

对于需求1,由于DDP在多卡中维护了相同的模型参数(通过在4张GPU上确保模型初始化以及广播相同的梯度来保证4张卡中的模型参数是完全相同的),因此只需要在其中一张卡保存模型即可:

def save_checkpoint(local_rank, ddp_model, path):
    #只在GPU 0 上保存模型
    if local_rank== 0:
        state = {
            'model': ddp_model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, path)

对于需求2,一般会使用torch.load()方法从磁盘加载文件:

def load_checkpoint(path):
    checkpoint = torch.load(path)
    model = Net()
    model.load_state_dict(checkpoint['model'])
    model = DDP(model, device_ids=[gpu])
    return model

但是此时往往会遇到多进程在GPU0上占用过多显存的问题:

使用nvidia-smi命令:

上图中,在所有使用GPU0的进程中,除了PID为62250的进程外,还存在其他三个进程,而这三个进程还分别使用GPU1\2\3。这三个额外进程在GPU0占用了725MB*3的显存空间,这可能会导致GPU0在训练时出现爆显存的问题。

在DDP中,会为每张卡单独创建一个进程:

上图的情况是正常的,每个进程只会使用与其对应的一张显卡。

该问题出现的原因是:torch.load()的不正确使用。

在pytorch对torch.load()方法的官方文档中,有以下说明:

If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved

意思是,如果map_location参数是空的,则torch.load方法会先把模型加载到CPU,然后把模型参数复制到保存它的地方(根据上文,保存模型的位置恰好是GPU 0)。

跑在GPU1上的进程在执行到torch.load方法后,会先加载模型到CPU,之后该进程顺理成章地调用GPU0,把一部分数据复制到GPU0,也就出现了前面图中的问题。

与其说是bug,倒不如说没仔细阅读文档。

两种解决方法方法。

一,将map_location指定为CPU:

def load_checkpoint(path):
    #加载到CPU
    checkpoint = torch.load(path,map_location='cpu')
    model = Net()
    model.load_state_dict(checkpoint['model'])
    model = DDP(model, device_ids=[gpu])
    return model

二,将map_location指定为local_rank对应的GPU:

def load_checkpoint(path):
    #加载到CPU
    checkpoint = torch.load(path,map_location='cuda:{}'.format(local_rank))
    model = Net()
    model.load_state_dict(checkpoint['model'])
    model = DDP(model, device_ids=[gpu])
    return model

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch使用DDP加载模型时出现多进程在GPU0上占用过多显存的问题 的相关文章

随机推荐

  • ieframe.dll修复方法

    ieframe dll文件对一些电脑软件 电脑游戏等程序的正常运行起到关键性作用 对于弹出缺少此类文件的弹窗 用户们很多时候也摸不着头脑 程序明明上次都能正常运行 突然就弹出缺少dll文件的提醒窗口 通过小编此次编辑的文章 用户们将可轻松解
  • 腾讯面经汇总

    网络 tcp超时 客户端断电了 我tcp怎么感知 A 断电操作系统就不会发送FIN 但tcp感知 emmmm send函数返回 1吧 Q 你确定吗 A 尬笑 Q 下去了好好研究研究吧 就让说TCP IP 然后我就说了TCP三次握手 四次挥手
  • Sass 循环语句

    本节我们学习 Sass 中的循环语句 Sass 中的循环语句可以使用 for 指令和 while 指令来实现 for指令 for 指令可以用于循环生成样式 for 指令有两种类型 如下所示 第一种 for i from
  • jperf服务器报告文档,Iperf 简单试验报告

    实验环境 服务器1 hadoop6 CentOS 6 5 X64 PC 自己的 iperf 有windows 版本 服务器 hadoop CentOS 6 5 X64 下载 https iperf fr iperf download php
  • 嵌入式设备上打印输出不及时-----fflush

    嵌入式设备上打印输出不及时的情况遇到过几次 有许多业务或者功能是通过printf函数将一些信息输出给其他应用 或者有些功能模块通过监控日志来做一些判断 如果打印输出不及时可能会有问题 之前写过一个获取驱动中无线帧格式的小程序 就是通过pri
  • 地震学AI模型

    1 地震数据格式 1 1 SAC SAC波形数据是以数据处理为目的的格式 这种格式一般只包含单个台站单个分量或多分量的数据 在SAC用户指南中描述了SAC波形数据输入格式 利用读入的数据 可以进一步做其他处理 如绘制波形图 1 2 SEED
  • 习题2-6 排列 算法竞赛入门经典(C/C++)

    用1 9九个数字组成三个三位数abc def ghi 每个数字恰好使用一次 要求三个数abc def ghi 1 2 3的所有可能 按照 abc def ghi 格式输出所有解 一行为一个解 样例输出 192 384 576 数据量级不大
  • django rest framework系列03-get使用方式基于token基本用户登录状态认证

    1 先看代码后讲解 views部分 from rest framework views import APIView from django http import JsonResponse from API import models f
  • idea 自定义注释 -- 类注释 方法注释

    自定义注释可以按照我们自己喜欢的风格 快速创建注释 废话不多 动起手来 一 在setting界面 根据流程进行设置 1 类注释 设置自定义注释格式 author USER createTime DATE TIME description 2
  • c语言高精度加法

    今天遇到一道题 让我写高精度加法 钻研了一会 写下了代码 include
  • [深入研究4G/5G/6G专题-57]: L3信令控制-6-什么是无线承载DRB Profile

    目录 第1章 什么是DRB Profile 1 1 什么是DRB 1 2 什么是DRB Profile 1 3 DRB Profile的作用 1 4 QCI profile
  • PIM-SM协议初探(一)路由角色选举

    PIM是Protocol Independent Multicast 协议无关组播 的简称 表示可以利用静态路由或者任意单播路由协议 包括RIP OSPF IS IS BGP等 所生成的单播路由表为IP组播提供路由 组播路由与所采用的单播路
  • html输出xml纯文本,将XML转换为纯文本

    我的目标是构建一个引擎 它使用最新的HL7 3 0CDA文档 并使它们与HL7 2 5向后兼容 后者是一个完全不同的野兽 CDA文档是一个XML文件 当与匹配的XSL文件配对时 它会呈现适合最终用户显示的HTML文档 在HL7 2 5中 我
  • “定制化人才” 的悲哀

    这篇博客写得就是自己现阶段的一些感悟 今天看到一个微信公众号的文章推送 标题就是 24岁后 你更应该逼自己系统性成长 只是看到这个标题就很有感触啊 因为还有一个月就24了 但是很迷茫 完全不知道自己的竞争力在哪里 可能唯一的优势大概就是前后
  • REDIS19_zipList压缩列表详解、快递列表 - QuickList、跳表 - SkipList

    文章目录 压缩列表 zipList 快递列表 QuickList 跳表 SkipList 压缩列表 zipList ZipList是一种特殊的 双端链表 由一系列特殊编码的连续内存块组成 可以在任意一端进行压入 弹出操作 并且该操作的时间复
  • CSRF(跨站请求伪造)详细说明

    Cross Site Request Forgery CSRF 中文一般译作跨站请求伪造 经常入选owasp漏洞列表Top10 在当前web漏洞排行中 与XSS和SQL注入并列前三 与前两者相比 CSRF相对来说受到的关注要小很多 但是危害
  • java符号解释大全,太完整了!

    微服务是什么 微服务起源于2005年Peter Rodgers博士在云端运算博览会提出的微Web服务 Micro Web Service 根本思想类似于Unix的管道设计理念 2014年 由Martin Fowler 与 James Lew
  • python中使用pymongo操作mongo

    MongoDB是由C 语言编写的非关系型数据库 是一个基于分布式文件存储的开源数据库系统 其内容存储形式类似JSON对象 它的字段值可以包含其他文档 数组及文档数组 非常灵活 在这一节中 我们就来看看Python 3下MongoDB的存储操
  • Hibernate学习笔记 多表映射

    前面说了Hibernate的单表映射 由于是实体类和数据表之间一对一的映射 所以比较简单 现在就来说说多表映射 这需要涉及到多个实体类和数据表之间的关系 因此稍微复杂一点 建立实体类 我建立了两个实体类 一个作者类 一个文章类 其他方法都忽
  • Pytorch使用DDP加载模型时出现多进程在GPU0上占用过多显存的问题

    使用pytorch DDP DistributedDataParallel 分布式数据并行 可以进行多卡训练 涉及到模型保存与加载问题时 一般会涉及到以下两种需求 将多卡训练的模型保存到磁盘 从磁盘加载模型 在多卡上继续训练 如何无bug且