Torch 1.9.1 DDP 并行优化与多模块调用问题

2023-11-06

DDP 基础实现

由于 DataParallel (DP) 采取的是多线程并行,出于其特性,会造成通信瓶颈 (GIL 限制),因此更高效的方式是使用 DistributedDataParallel 实现更高效的 GPU 使用。DDP 相关基础实现参考此处文章,亲测可以使用。

:目前使用 1.7+ pytorch nccl 初始化 DDP 会报错,亲测也有该问题,因此建议使用 ‘gloo’,虽然相对速度可能较慢。

多模块调用问题

假设考虑如下问题:

import torch

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import torch.optim as optim
import torch.nn as nn

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank

torch.cuda.set_device(local_rank)
dist.init_process_group(backend='gloo')

device = torch.device("cuda", local_rank)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.device = device

        self.nn1 = nn.Linear(10, 20)
        self.nn2 = nn.Linear(20, 10)

    def forward(self, x):
        mid = self.nn1(x)
        model = self.nn2(mid)

        return model

    def middle(self, x):
        return self.nn1(x)


model = Net().to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

x = torch.randn(20, 10).to(local_rank)
outputs = model(x)
mid = model.middle(x)

labels = torch.randn(20, 10).to(local_rank)
loss_fn = nn.MSELoss()
loss_fn(outputs, labels).backward()

optimizer = optim.SGD(model.parameters(), lr=0.001)
optimizer.step()

如果没有 model_ddp = DDP(model, device_ids=[local_rank], output_device=local_rank) 的操作,上述程序是个简单例子。注意到此处 Net 模型有一个 intrinsic forwar() 函数 和 其他函数 middle(),在调用 mid = model.middle(x) 会报错,因为 model 转换为 DDP 形式之后只会执行 forward()。

考虑如下解决方式:

DDP.module.func()

查看 DDP 实现源码可以发现,DDP 可以通过模块引用从而调用 其他函数 middle()。将

mid = model.middle(x)

改为

mid = model.module.middle(x)

可以找到 middle() 函数,注意此处 mid 是非 DDP 形式的。

并行非并行分离

类似 StyleGAN2, 我们可以考虑创建并行 model_ddp() 模型和 model() 非并行模型,由于其他函数一般不参与模型学习,而是中间状态输出,因此这种情况下我们可以使用双模型。如下:

model = Net().to(device)
model_ddp = DDP(model, device_ids=[local_rank], output_device=local_rank)

forward() 条件引入 middle

由于问题是 DDP 只会执行 forward(), 因此我们可以考虑将 middle 进行条件引入:

def forward(self, x, run_middle=False):
    if run_middle:
        return self.middle(x)
    mid = self.nn1(x)
    model = self.nn2(mid)

    return model

此为一种最为通用的策略。

上述三种策略根据需求可以结合使用。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Torch 1.9.1 DDP 并行优化与多模块调用问题 的相关文章

  • 如果两点之间的距离低于某个阈值,则从列表中删除点

    我有一个点列表 只有当它们之间的距离大于某个阈值时 我才想保留列表中的点 因此 从第一个点开始 如果第一个点和第二个点之间的距离小于阈值 那么我将删除第二个点 然后计算第一个点和第三个点之间的距离 如果该距离小于阈值 则比较第一点和第四点
  • 使用 python requests 模块时出现 HTTP 503 错误

    我正在尝试发出 HTTP 请求 但当前可以从 Firefox 浏览器访问的网站响应 503 错误 代码本身非常简单 在网上搜索一番后我添加了user Agent请求参数 但也没有帮助 有人能解释一下如何消除这个 503 错误吗 顺便说一句
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • 将html数据解析成python列表进行操作

    我正在尝试读取 html 网站并提取其数据 例如 我想查看公司过去 5 年的 EPS 每股收益 基本上 我可以读入它 并且可以使用 BeautifulSoup 或 html2text 创建一个巨大的文本块 然后我想搜索该文件 我一直在使用
  • Python 中的舍入浮点问题

    我遇到了 np round np around 的问题 它没有正确舍入 我无法包含代码 因为当我手动设置值 而不是使用我的数据 时 返回有效 但这是输出 In 177 a Out 177 0 0099999998 In 178 np rou
  • 处理 Python 行为测试框架中的异常

    我一直在考虑从鼻子转向行为测试 摩卡 柴等已经宠坏了我 到目前为止一切都很好 但除了以下之外 我似乎无法找出任何测试异常的方法 then It throws a KeyError exception def step impl contex
  • Python getstatusoutput 替换不返回完整输出

    我发现了这个很棒的替代品getstatusoutput Python 2 中的函数在 Unix 和 Windows 上同样有效 不过我觉得这个方法有问题output被构建 它只返回输出的最后一行 但我不明白为什么 任何帮助都是极好的 def
  • 使用 kivy textinput 的 'input_type' 属性的问题

    您好 我在使用 kivy 的文本输入小部件的 input type 属性时遇到问题 问题是我制作了两个自定义文本输入 其中一个称为 StrText 其中设置了 input type text 然后是第二个文本输入 名为 NumText 其
  • 如何使用 Pandas、Numpy 加速 Python 中的嵌套 for 循环逻辑?

    我想检查一下表的字段是否TestProject包含了Client端传入的参数 嵌套for循环很丑陋 有什么高效简单的方法来实现吗 非常感谢您的任何建议 def test parameter a list parameter b list g
  • Python 2:SMTPServerDisconnected:连接意外关闭

    我在用 Python 发送电子邮件时遇到一个小问题 me my email address you recipient s email address me email protected cdn cgi l email protectio
  • 如何使用 Mysql Python 连接器检索二进制数据?

    如果我在 MySQL 中创建一个包含二进制数据的简单表 CREATE TABLE foo bar binary 4 INSERT INTO foo bar VALUES UNHEX de12 然后尝试使用 MySQL Connector P
  • Numpy - 根据表示一维的坐标向量的条件替换数组中的值

    我有一个data多维数组 最后一个是距离 另一方面 我有距离向量r 例如 Data np ones 20 30 100 r np linspace 10 50 100 最后 我还有一个临界距离值列表 称为r0 使得 r0 shape Dat
  • 加快网络抓取速度

    我正在使用一个非常简单的网络抓取工具抓取 23770 个网页scrapy 我对 scrapy 甚至 python 都很陌生 但设法编写了一个可以完成这项工作的蜘蛛 然而 它确实很慢 爬行 23770 个页面大约需要 28 小时 我看过scr
  • 不同编程语言中的浮点数学

    我知道浮点数学充其量可能是丑陋的 但我想知道是否有人可以解释以下怪癖 在大多数编程语言中 我测试了 0 4 到 0 2 的加法会产生轻微的错误 而 0 4 0 1 0 1 则不会产生错误 两者计算不平等的原因是什么 在各自的编程语言中可以采
  • 仅第一个加载的 Django 站点有效

    我最近向 stackoverflow 提交了一个问题 标题为使用mod wsgi在apache上多次请求后Django无限加载 https stackoverflow com questions 71705909 django infini
  • Python:XML 内所有标签名称中的字符串替换(将连字符替换为下划线)

    我有一个格式不太好的 XML 标签名称内有连字符 我想用下划线替换它 以便能够与 lxml objectify 一起使用 我想替换所有标签名称 包括嵌套的子标签 示例 XML
  • 将 Python 中的日期与日期时间进行比较

    所以我有一个日期列表 datetime date 2013 7 9 datetime date 2013 7 12 datetime date 2013 7 15 datetime date 2013 7 18 datetime date
  • 如何应用一个函数 n 次? [关闭]

    Closed 这个问题需要细节或清晰度 help closed questions 目前不接受答案 假设我有一个函数 它接受一个参数并返回相同类型的结果 def increment x return x 1 如何制作高阶函数repeat可以
  • Pandas 每周计算重复值

    我有一个Dataframe包含按周分组的日期和 ID df date id 2022 02 07 1 3 5 4 2022 02 14 2 1 3 2022 02 21 9 10 1 2022 05 16 我想计算每周有多少 id 与上周重
  • 更改 Tk 标签小部件中单个单词的颜色

    我想更改 Tkinter 标签小部件中单个单词的字体颜色 我知道可以使用文本小部件来实现与我想要完成的类似的事情 例如使单词 YELLOW 显示为黄色 self text tag config tag yel fg clr yellow s

随机推荐

  • Linux Power Supply架构及代码解析

    一 概述 电源管理整体上可以分为两个部分 一个是电池监控 fuel gauge 另外一个是充放电管理 这两部分在内核中也是分为两个驱动来管理 fuelgauge驱动的功能主要是负责向上层Android系统提供当前电池的电量和健康信息等等 同
  • redis学习总结

    文章目录 redis数据结构原理 简单字符串SDS 叫Simple dynamic string 链表 字典 跳跃表 redis持久化 RDB持久化 AOF持久化 redis集群三种模式 主从模式 实现主从分离 提高吞吐 多机备份 哨兵模式
  • Python填写问卷星

    主要使用python实现问卷星的自动填写和提交 主要使用了https www jianshu com p 34961ceedcb4的代码 使用了X Forwarded For自动修改ip 我测试的时候是可以使用的 PS 我是在linux下面
  • idea 设置自动添加注释

    添加类注释 打开Settings 点击Apply OK 添加方法注释 添加组 选择test 添加Live Template text如下 Author yeluo Description description param param re
  • JSONObject对象的方法

    JSONObject 是 org json 库中的一个类 用于创建和操作 JSON 对象 以下是一些常用的 JSONObject 方法 1 put key value 向 JSON 对象中添加键值对 jsonObject put key v
  • 锂电池充放电电路设计与分析

    Lithium battery charge 锂电池充放电电路 1 USB插入检测电路 1 1 FUSE1 自恢复保险丝 当后续的电路发生短路等故障时 自动启动保护作用来保护外围的电源 避免损坏 因为经常出事故一般是电源出事故了 电源短路
  • leetcode_第17题_缺失的第一个正数——原地哈希

    题目 题目 分析 正常思路 另外制作一个哈希表 然后遍历就ok 但是这样不符合题目空间复杂度要求 所以采用原地哈希就可以了 思路 把正常数字nums i 交换存储到下标位置为nums i 1的地方 不正常数字不管 正常数字是指 值 1 le
  • linux(ubuntu)下C++访问mysql数据库

    Ubuntu安装msyql 安装mysql数据库 1 sudo apt get install mysql server 安装mysql客户端
  • HTTP服务器(二)

    前面已经实现了服务器的整体框架 现在就来具体实现HTTP服务器处理静态页面的逻辑 要获取具体的静态文件 就要知道要获取的文件的路径 我们分析url 协议方案名 使用http 或https 等协议方案名获取访问资源时要指定的协议类型 登录信息
  • 1.mysql体系结构

    中文文档 mysql 5 1中文文档 一 MySql服务器和客户端 1 客户端和服务器服务器是指安装mysql的那台机器 而客户端是远程通过网络使用服务器上的mysql 客户端通过得知远程服务器的ip地址以及mysql的一些密码信息等使用m
  • “探秘JS加密算法:MD5、Base64、DES/AES、RSA你都知道吗?”

    目录 1 什么是JS JS反爬是什么 JS逆向是什么 2 JS逆向的大致流程 3 逆向的环境搭建 3 1 安装node js 3 2 安装js代码调试工具 vscode 3 3 安装PyExecJs模块 4 JS常见加密算法 4 1 Bas
  • Spring Boot 开启Giz

    Enable response compression server compression enabled true The comma separated list of mime types that should be compre
  • C++关键字

    注意单引号 a 97 A 65 include using namespace std
  • Java socket通信实例,简单入门socket实例代码

    是不是看了许多socket入门知识 却还是不能实际运用呢 这篇文章通过利用简单实例程序讲解通过socket实现客户端与服务器之间的通讯 这篇文章可以让你不需要了解socket原理也能利用 便于应急 但建议之后要好好补补关于soket的基础知
  • 安装jdk后HelloWorld测试

    编写HelloWorld java文件 源码如下 public class HelloWorld public static void main String args System out println Hello World in U
  • 数据库实验三 单表查询

    一 实验目的 理解SELECT语句的操作和基本使用方法 二 实验题目 1 查询全体学生的姓名 学号 所在系 SELECT Sname Sno Sdept FOEM studentflx 2 查询选修了课程的学生学号 SELECT DISTI
  • IDEA出现Please refer to dump files (if any exist) [date].dump, [date]-jvmRun[N].dump and [date].dumpst

    错误截图 解决方法 关了maven的运行检查就好了 maven的编译打包检查 关闭点一下就可以了 忽略检查测试文件
  • python模拟退火算法 水平耦合强度

    水平耦合强度 horizontal bonds 0 2242 0 8894 0 9625 1 3939 1 2604 1 7343 0 0290 0 0731 0 0770 0 4400 1 6270 0 0596 0 0690 0 119
  • 教你解决浏览器被360劫持篡改主页的麻烦

    前言 相信很多的小伙伴都遇到一个问题 就是好端端的 打开自己的edge或者Chrome 突然发现自己的主页变成了这样 下图 不得不说 这个看得人真的不适 晕 相信大部分人还是喜欢简洁的 而且主要的是 自己的浏览器被可恶的360给篡改了 真是
  • Torch 1.9.1 DDP 并行优化与多模块调用问题

    DDP 基础实现 由于 DataParallel DP 采取的是多线程并行 出于其特性 会造成通信瓶颈 GIL 限制 因此更高效的方式是使用 DistributedDataParallel 实现更高效的 GPU 使用 DDP 相关基础实现参