Pytorch多GPU训练时使用hook提取模型中间层输出时与模型输入张量不在同一个GPU上的解决办法

2023-05-16

Pytorch多GPU训练时使用hook提取模型中间层输出时与模型输入张量不在同一个GPU上的解决办法

通常对于单卡训练的模型,使用hook可以较为方便地提取出模型中间层输出。
例如我们想要获取自定义模型DBL中的conv2d的输出,可以先打印出这个网络,获取到conv2d在模型中的次序,然后使用for循环确定其位置并注册hook。
参考https://www.jianshu.com/p/0a270d63aca9

import torch
import torch.nn as nn

class CBL(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, groups):
        super(CBL, self).__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=1, padding=pad,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5),
            nn.LeakyReLU(0.1),
        )

        ## hook相关代码
        self.mid_fea = []
        for index_i, (name, module) in enumerate(self.named_modules()):
            if index_i == 0:        # conv在模型中的序号是0
                module.register_forward_hook(hook=self.layer_hook)
                # 必须在前向推理之前声明hook
                break
                
    def layer_hook(self, module, fea_in, fea_out):
        self.mid_fea.append(fea_out)

    def forward(self, x):
        out = self.conv(x)
        return out

if __name__ == "__main__":
	# 这里为了方便没有使用gpu
    model = CBL(8, 16, 3, 1)
    x = torch.ones(1, 8, 10, 10)
    out = model(x)
    print(model.mid_fea[0])

然而当我们使用多个GPU训练模型时,上述方法得到的中间层输出可能总是与模型输入张量不在同一个gpu上,这可能会导致后续的计算报错。即使使用to(device),似乎总是不能把中间层输出移动到指定的gpu上。查了半天,论坛上给出了一个解决方法:不要使用列表保存中间层输出,而是使用字典,将不同的device上的中间层分别存放。示例如下
参考网址:https://discuss.pytorch.org/t/register-forward-hook-with-multiple-gpus/12115

import torch
import torch.nn as nn


class CBL(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, groups):
        super(CBL, self).__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=1, padding=pad,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5),
            nn.LeakyReLU(0.1),
        )

        ## hook相关代码
        self.mid_fea = {}
                for index_i, (name, module) in enumerate(self.named_modules()):
            if index_i == 1:        # conv在模型中的序号是1
                module.register_forward_hook(hook=self.layer_hook)
                # 必须在前向推理之前声明hook
                break

    def layer_hook(self, module, fea_in, fea_out):
        self.mid_fea[fea_in[0].device].append(fea_out)

    def forward(self, x):
        self.mid_fea[x.device] = []
        out = self.conv(x)
        return out, self.mid_fea[x.device][0]		# 返回模型输出以及中间层特征


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CBL(8, 16, 3, 1).to(device)
    model = nn.DataParallel(model)          # 使用多张gpu

    x = torch.ones(2, 8, 10, 10)
    out, mid_fea = model(x)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch多GPU训练时使用hook提取模型中间层输出时与模型输入张量不在同一个GPU上的解决办法 的相关文章

  • 连接 Hibernate 的查询生成

    我想实施虚拟视图与预处理器 一个简单的例子 之前的HQL FROM PublishedArticle a 生效后的 HQL FROM Article a WHERE a published true 本质上 我需要一种在执行查询之前处理查询
  • 从活动顶点数组生成平滑法线

    我正在尝试通过挂钩 OpenGl 调用来破解和修改旧版 opengl 固定管道游戏的多个渲染功能 而我当前的任务是实现着色器照明 我已经创建了一个适当的着色器程序 可以正确照亮大部分对象 但该游戏的地形是在没有提供正常数据的情况下绘制的 游
  • 同时使用 2 个 GPU 调用 cudaMalloc 时性能较差

    我有一个应用程序 可以在用户系统上的 GPU 之间分配处理负载 基本上 每个 GPU 都有一个 CPU 线程来启动一个GPU处理间隔当由主应用程序线程定期触发时 考虑以下图像 使用 NVIDIA 的 CUDA 分析器工具生成 作为示例GPU
  • 使 CUDA 内存不足

    我正在尝试训练网络 但我明白了 我将批量大小设置为 300 并收到此错误 但即使我将其减少到 100 我仍然收到此错误 更令人沮丧的是 在 1200 个图像上运行 10 epoch 大约需要 40 分钟 有什么建议吗 错了 我怎样才能加快这
  • 如何计算 CNN 第一个线性层的维度

    目前 我正在使用 CNN 其中附加了一个完全连接的层 并且我正在使用尺寸为 32x32 的 3 通道图像 我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积 最大池层的输入 我希望能够计算第一个线性层的尺寸 仅给出
  • 将 Keras (Tensorflow) 卷积神经网络转换为 PyTorch 卷积网络?

    Keras 和 PyTorch 使用不同的参数进行填充 Keras 需要输入字符串 而 PyTorch 使用数字 有什么区别 如何将一个转换为另一个 哪些代码在任一框架中获得相同的结果 PyTorch 还采用参数 in channels o
  • NvCplGetThermalSettings 返回 false

    问题 您好 我正在尝试使用 Delphi 获取 nividia gtx 980 的 GPU 温度 我看过C 问题 他的解决方案是不使用nvcpl dll 我认为这不是正确的解决方案 因为 nivida 有完整的文档说明如何处理 API 见下
  • 如何有效地对一个数组中某个值在另一个数组中的位置出现的次数求和

    我正在寻找一种有效的 for 循环 避免解决方案来解决我遇到的数组相关问题 我想使用一个巨大的一维数组 A gt size 250 000 用于一维索引的 0 到 40 之间的值 以及用于第二维索引的具有 0 到 9995 之间的值的相同大
  • Pytorch CUDA 错误:没有内核映像可用于在带有 cuda 11.1 的 RTX 3090 设备上执行

    如果我运行以下命令 import torch import sys print A sys version print B torch version print C torch cuda is available print D torc
  • torch.stack() 和 torch.cat() 函数有什么区别?

    OpenAI 的强化学习 REINFORCE 和 actor critic 示例具有以下代码 加强 https github com pytorch examples blob master reinforcement learning r
  • iOS 上的 OpenCV - GPU 使用情况?

    我正在尝试开发一个 iOS 应用程序 可以对来自相机的视频执行实时效果 就像 iPad 上的 Photobooth 一样 我熟悉 OpenCV 的 API 但如果大多数处理是在 CPU 上完成而不是在 GPU 上完成 我担心 iOS 上的性
  • Android应用程序:java / JNI调用挂钩策略

    我的目标是检测 AOSP 以便动态记录来自目标应用程序的所有 java 或 JNI 调用 带或不带参数和返回值 我不想修改应用程序 这就是我想要修改 Android 源代码的原因 我对 AOSP 及其众多库和框架不是很有经验 所以我正在寻求
  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • 如何从已安装的云端硬盘文件夹中永久删除?

    我编写了一个脚本 在每次迭代后将我的模型和训练示例上传到 Google Drive 以防发生崩溃或任何阻止笔记本运行的情况 如下所示 drive path drive My Drive Colab Notebooks models if p
  • 从 CUDA 设备写入输出文件

    我是 CUDA 编程的新手 正在将 C 代码重写为并行 CUDA 新代码 有没有一种方法可以直接从设备写入输出数据文件 而无需将数组从设备复制到主机 我假设如果cuPrintf存在 一定有地方可以写一个cuFprintf 抱歉 如果答案已经
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p
  • 如何在 PyTorch 中对子集使用不同的数据增强

    如何针对不同的情况使用不同的数据增强 转换 Subset在 PyTorch 中吗 例如 train test torch utils data random split dataset 80000 2000 train and test将具
  • 如何为 CUDA 内核选择网格和块尺寸?

    这是一个关于如何确定CUDA网格 块和线程大小的问题 这是对已发布问题的附加问题here https stackoverflow com a 5643838 1292251 通过此链接 talonmies 的答案包含一个代码片段 见下文 我
  • TensorFlow 相当于 PyTorch 的 Transforms.Normalize()

    我正在尝试推断最初在 PyTorch 中构建的 TFLite 模型 我一直在遵循PyTorch 实现 https github com leoxiaobin deep high resolution net pytorch blob 1ee
  • PyTorch 给出 cuda 运行时错误

    我对我的代码做了一些小小的修改 以便它不使用 DataParallel and DistributedDataParallel 代码如下 import argparse import os import shutil import time

随机推荐

  • 阿里云服务器

    一年多之前 xff0c 也就11年5月份的样子 xff0c 阿里云云服务器产品线终于上线了 但那时候 xff0c 国内完全没有能称得上云服务器的 xff0c 很多小公司就是搞个VPS就叫云服务器了 以至于阿里云云服务器刚出来的时候 xff0
  • mac 下 使用 iterm2 配置及快键键使用

    mac 下 使用 iterm2 配置及快键键使用 标签 xff08 空格分隔 xff09 xff1a mac 之前介绍过一篇关于mac 下使用和配置 iterm2的blog 今天这篇稍微详细一点介绍 并且搭配 zsh zsh 会单独开一篇博
  • Java实现快速排序

    一 原理 快速排序算法通过多次比较和交换来实现排序 xff0c 其排序流程如下 xff1a 1 首先设定一个分界值 xff0c 通过该分界值将数组分成左右两部分 2 将大于或等于分界值的数据集中到数组右边 xff0c 小于分界值的数据集中到
  • C#,生信软件实践(03)——DNA数据库GenBank格式详解及转为FASTA序列格式的源代码

    1 GenBank 1 1 NCBI 美国国家生物技术信息中心 xff08 美国国立生物技术信息中心 xff09 NCBI xff08 美国国立生物技术信息中心 xff09 是在NIH的国立医学图书馆 xff08 NLM xff09 的一个
  • 【坑】zsh和oh-my-zsh卸载后导致无法登陆

    apt get remove zsh 然后断开终端 xff0c 就再也连不上了 xff0c 崩溃啊 xff01 以下登陆为www用户登陆 各种找 xff0c 到这里 https www cnblogs com EasonJim p 7863
  • 获取最近使用应用列表

    获取最近使用的应用列表需要使用到UsageStatsManager类 xff0c 还需要申请允许防御应用使用情况的权限 private void getPackagesInfo UsageStatsManager manager 61 Us
  • 使用MediaProjectionManager进行截屏

    最近项目中有用到远程截屏并上传截屏文件的需求 一开始使用的是以下方法进行截屏 xff1a private void screenshot 获取屏幕 View dView 61 getWindow getDecorView dView set
  • 安卓TV开发遇到的那些坑

    最近公司需要开发一个TV的luancher xff0c 就是那种纯物理按键的遥控 xff0c 没有触摸屏 xff0c 现在说说我踩得那些坑 xff08 其实布局和代码逻辑和正常的安卓应用差不多 xff09 1 焦点 焦点 焦点 xff0c
  • 安卓TV列表刷新时焦点自动变成第一个

    最近在开发安卓TV项目 xff0c 列表调用notifyDataSetChanged xff08 xff09 方法刷新数据时 xff0c 焦点自动就变成第一个子item去了 xff0c 查了半天发现用notifyItemRangeChang
  • 安卓蓝牙BLE设备通讯发送和接受超过20个字节的问题

    最近做的项目是手机端和BLE设备通讯 xff0c 而BLE设备又做了限制一次包只能传递20个字节的数据 xff0c 多了就得分包发送 xff0c 在这里记录一下如何解决这个问题 xff08 PS xff1a 之前链接什么的回调什么的 就不过
  • 获取最近运行应用方法和杀进程的方法

    最近公司的项目有个需求就是获取最近手机正在运行的进程 xff0c 以及杀掉进程 就是类似于安卓手机中的长按home键的效果 先说说获取最近手机正在运行的进程方法 xff1a 直接上代码 xff0c 代码中有注释 xff1a appbeans
  • 把自己的应用程序push至system/app下,把自己的app改成系统级别的app

    想把一个应用程序放入到系统文件夹下的话 xff0c 手机必须的root的情况下才能push进去 下面我就说说步骤吧 xff1a 1 先把手机用USB和电脑连接 2 如果电脑配置了adb的环境的话直接cmd xff0c 未配置环境的话找到sd
  • ConcurrentModificationException异常出现原因以及解决方法

    今天在开发过程中遇到一个异常叫ConcurrentModificationException xff0c 这个异常用我的白话翻译是叫同时修改异常 这个异常是怎么出现的呢 xff0c 先看看已下的代码 xff1a span class hlj
  • retrofit中使用body标签传RequestBody

    现在的Android开发者基本上都用过retrofit这个第三方网络请求库吧 xff01 xff01 xff01 网络请求中有get post delete和put等等请求方式 现在我们需要用到post请求 xff1a span class
  • SpringBoot配置拦截器拦截器使用

    拦截器介绍 Java中的拦截器是动态拦截 action 调用的对象 xff0c 然后提供了可以在 action 执行前后增加一些操作 xff0c 也可以在 action执行前停止操作 xff0c 功能与过滤器类似 xff0c 但是标准和实现
  • 百度地图上根据经纬度集合绘制行车轨迹

    以下是素材 最近项目中用到了根据一段线路的经纬度集合来在地图上播放该车辆的行驶轨迹的需求 下面我就讲一下我实现步骤 效果图如下 因为制作gif图为了控制大小去掉了很多帧 不必在意这些细节 嘿嘿 1 首先在界面上展示百度地图 这不是废话么 如
  • skip-GANomaly复现总结

    文章目录 skip GANomaly复现总结附MvTec数据集介绍实验结果总结谈谈我对于skip GANomaly的看法最后的感想 代码 skip GANomaly复现总结 附MvTec数据集 链接 xff1a https pan baid
  • YOLOv3 从入门到部署:(五)YOLOv3模型的部署(基于C++ opencv)

    文章目录 YOLOv3 从入门到部署 xff1a xff08 五 xff09 YOLOv3模型的部署 xff08 基于C 43 43 opencv xff09 目录关于opencv的DNN介绍代码讲解效果展示 YOLOv3 从入门到部署 x
  • 基于YOLO-fastest-xl的OCR

    文章目录 基于YOLO fastest xl的OCR项目介绍对于yolo fastest xl的结构的更改运行方法效果总结 基于YOLO fastest xl的OCR github链接https github com qqsuhao yol
  • Pytorch多GPU训练时使用hook提取模型中间层输出时与模型输入张量不在同一个GPU上的解决办法

    Pytorch多GPU训练时使用hook提取模型中间层输出时与模型输入张量不在同一个GPU上的解决办法 通常对于单卡训练的模型 xff0c 使用hook可以较为方便地提取出模型中间层输出 例如我们想要获取自定义模型DBL中的conv2d的输