为什么 PyTorch nn.Module.cuda() 不将模块张量移动到 GPU,而仅将参数和缓冲区移动到 GPU?

2024-04-12

nn.Module.cuda()将所有模型参数和缓冲区移动到 GPU。

但为什么不是模型成员张量呢?

class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super(ToyModule, self).__init__()
        self.layer = torch.nn.Linear(2, 2)
        self.expected_moved_cuda_tensor = torch.tensor([0, 2, 3])

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.layer(input)

toy_module = ToyModule()
toy_module.cuda()
next(toy_module.layer.parameters()).device
>>> device(type='cuda', index=0)

对于模型成员张量,设备保持不变。

>>> toy_module.expected_moved_cuda_tensor.device
device(type='cpu')

如果您在模块内定义张量,则需要将其注册为参数或缓冲区,以便模块知道它。


参数是要训练的张量,并将通过以下方式返回model.parameters()。它们很容易注册,您所需要做的就是将张量包装在nn.Parameter输入,它将自动注册。请注意,只有浮点张量可以作为参数。

class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super(ToyModule, self).__init__()
        self.layer = torch.nn.Linear(2, 2)
        # registering expected_moved_cuda_tensor as a trainable parameter
        self.expected_moved_cuda_tensor = torch.nn.Parameter(torch.tensor([0., 2., 3.]))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.layer(input)

Buffers是将在模块中注册的张量,因此方法如下.cuda()会影响他们,但他们会not被返回model.parameters()。缓冲区不限于特定的数据类型。

class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super(ToyModule, self).__init__()
        self.layer = torch.nn.Linear(2, 2)
        # registering expected_moved_cuda_tensor as a buffer
        # Note: this creates a new member variable named expected_moved_cuda_tensor
        self.register_buffer('expected_moved_cuda_tensor', torch.tensor([0, 2, 3])))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.layer(input)

在上述两种情况下,以下代码的行为相同

>>> toy_module = ToyModule()
>>> toy_module.cuda()
>>> next(toy_module.layer.parameters()).device
device(type='cuda', index=0)
>>> toy_module.expected_moved_cuda_tensor.device
device(type='cuda', index=0)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为什么 PyTorch nn.Module.cuda() 不将模块张量移动到 GPU,而仅将参数和缓冲区移动到 GPU? 的相关文章

  • 理解Python中的元类和继承[重复]

    这个问题在这里已经有答案了 我对元类有一些困惑 具有继承性 class AttributeInitType object def init self kwargs for name value in kwargs items setattr
  • 隐藏控制台并执行 python 脚本

    我正在尝试使用 pyinstaller 在 Windows 10 上使用 pyqt5 模块编译在 python 3 中构建的 python 脚本 该脚本在运行时隐藏窗口 为了编译我的脚本 我执行了以下命令 pyinstaller onefi
  • MySQL 的 read_sql() 非常慢

    我将 MySQL 与 pandas 和 sqlalchemy 一起使用 然而 它的速度非常慢 对于一个包含 1100 万行的表 一个简单的查询需要 11 分钟以上才能完成 哪些行动可以改善这种表现 提到的表没有主键 并且仅由一列索引 fro
  • 如何/在哪里发布 Python 包

    如果一个人创建了一个有用的 Python 包 那么如何 在哪里发布 宣传它以供其他人使用 我已经把它放到了 github 上 但几周后谷歌也没有找到它 包装整洁完整 我制作它供我个人使用 不与其他人分享将是一种耻辱 这是 PyPI 指南 h
  • Python:如何删除圆括号内的文本?

    我试过了 但没用 return re sub myResultStats text 建议 thanks 尝试这个 return re sub myResultStats text 括号表示捕获组 因此您必须转义它们
  • 通过 pyodbc 连接到 Azure SQL 数据库

    我使用 pyodbc 连接到本地 SQL 数据库 该数据库工作正常 SQLSERVERLOCAL Driver SQL Server Native Client 11 0 Server localdb v11 0 integrated se
  • 使用python同时播放两个正弦音

    我正在使用 python 来播放正弦音 音调基于计算机的内部时间 以分钟为单位 但我想根据秒同时播放一个音调 以获得和谐或双重的声音 这就是我到目前为止所拥有的 有人能指出我正确的方向吗 from struct import pack fr
  • 将带有非字符串关键字的 dict 传递给 kwargs 中的函数

    我使用具有签名功能的库f args kwargs 我需要在 kwargs 参数中传递 python dict 但 dict 不包含关键字中的字符串 f 1 2 3 4 Traceback most recent call last File
  • 将图像转换为二进制流

    我的应用程序有两个方面 一方面我使用 C 来使用 Pleora 的 EBUS SDK 从相机读取帧 当第一次接收到该流时 在将缓冲区转换为图像之前 我能够一次读取 16 位流 以便对每个像素执行一些计算 即每个像素都存在一个 16 位数据块
  • 可重用的 Tensorflow 卷积网络

    我想重用来自Tensorflow 专业人士的 MNIST CNN 示例 http www tensorflow org tutorials mnist pros index md 我的图像尺寸为 388px X 191px 只有 2 个输出
  • python中remove方法的安全使用

    我从列表继承了一个 UserList 类并实现了以下方法来删除标记为已删除的条目 def purge deleted self for element in list iter self if ele mark deleted lt 1 s
  • 从文件中读取单词并放入列表中

    本质上 我有一个巨大的文件 所有文件包含每行多个单词 每个单词用空格分隔 有点像这样 WORD WORD WORD WORD ANOTHER WORD SCRABBLE BLAH YES NO 我想要做的是将文件中的所有单词放入一个巨大的列
  • Python、cPickle、酸洗 lambda 函数

    我必须像这样腌制一组对象 import cPickle as pickle from numpy import sin cos array tmp lambda x sin x cos x test array tmp tmp tmp tm
  • 在 Python 中将 int 转换为 ASCII 并返回

    我正在为我的网站制作一个 URL 缩短器 我当前的计划 我愿意接受建议 是使用节点 ID 来生成缩短的 URL 因此 理论上 节点 26 可能是short com z 节点 1 可能是short com a 节点 52 可能是short c
  • 使用 pythons strftime 显示日期,例如“5 月 5 日”? [复制]

    这个问题在这里已经有答案了 可能的重复 Python 日期顺序输出 https stackoverflow com questions 739241 python date ordinal output 在Python中 time strf
  • 使用 PIL 合并图像时模式不匹配

    我正在传递 jpg 文件的名称 def split image into bands filename img Image open filename data img getdata red d 0 0 0 for d in data L
  • 使用 PuLP 进行线性优化,变量附加条件

    我必须用 Pull 解决 Python 中的整数线性优化问题 我解决了基本问题 现在我必须添加额外的约束 有人可以帮助我用逻辑指示器添加条件吗 逻辑限制是 如果 A gt 20 则 B gt 5 这是我的代码 from pulp impor
  • 矩阵求逆 (3,3) python - 硬编码与 numpy.linalg.inv

    对于大量矩阵 我需要计算定义为的距离度量 尽管我确实知道强烈建议不要使用矩阵求逆 但我没有找到解决方法 因此 我尝试通过对矩阵求逆进行硬编码来提高性能 因为所有矩阵的大小均为 3 3 我预计这至少会是一个微小的改进 但事实并非如此 为什么
  • 应用程序的外观 - Py2exe / wxPython

    所以我的问题是我的应用程序的外观和感觉 因为它看起来像一个旧的外观应用程序 它是一个 wxPython 应用程序 在 python 上它运行良好并且看起来不错 但是当我使用 py2exe 将其转换为 exe 时 外观很糟糕 现在我知道如果你
  • 如何访问模板缓存? - 姜戈

    I am 缓存 HTML在几个模板内 例如 cache 900 stats stats endcache 我可以使用以下方式访问缓存吗低级图书馆 例如 html cache get stats 我确实需要对模板缓存进行一些细粒度的控制 有任

随机推荐