在 PyTorch 中使用 module.to() 移动成员张量

2024-01-03

我正在 PyTorch 中构建变分自动编码器 (VAE),但在编写与设备无关的代码时遇到问题。自动编码器是nn.Module具有编码器和解码器网络,它们也是。网络的所有权重都可以通过调用从一台设备移动到另一台设备net.to(device).

我遇到的问题是重新参数化技巧:

encoding = mu + noise * sigma

噪声是一个与以下大小相同的张量mu and sigma并保存为自动编码器模块的成员变量。它在构造函数中初始化,并在每个训练步骤中就地重新采样。我这样做是为了避免每一步构建一个新的噪声张量并将其推送到所需的设备。此外,我想修复评估中的噪音。这是代码:

class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise

        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

    def train(self, mode=True):
        super(VariationalGenerator, self).train(mode)
        self.noise = self._train_noise

    def eval(self):
        super(VariationalGenerator, self).eval()
        self.noise = self._eval_noise

    def forward(self, inputs):
        # Calculate parameters of embedding space
        mu, log_sigma = self.encoder.forward(inputs)
        # Resample noise if training
        if self.training:
            self.noise.normal_()
        # Reparametrize noise to embedding space
        inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
        # Decode to image
        inputs = self.decoder(inputs)

        return inputs, mu, log_sigma

当我现在将自动编码器移动到 GPU 时net.to('cuda:0')我在转发时遇到错误,因为噪声张量没有移动。

我不想向构造函数添加设备参数,因为这样以后仍然无法将其移动到另一个设备。我也尝试将噪音包裹起来nn.Parameter从而使其受到影响net.to(),但这会给优化器带来错误,因为噪声被标记为requires_grad=False.

任何人都有一个解决方案来移动所有模块net.to()?


更好的版本tilman151的第二种方法 https://stackoverflow.com/a/54768936/344821可能是覆盖_apply, 而不是to。那样net.cuda(), net.float()等都可以工作,因为它们都调用_apply而不是to(可以看出来源 https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py,这比你想象的要简单):

def _apply(self, fn):
    super(VariationalGenerator, self)._apply(fn)
    self._train_noise = fn(self._train_noise)
    self._eval_noise = fn(self._eval_noise)
    return self
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 PyTorch 中使用 module.to() 移动成员张量 的相关文章

随机推荐

  • ILASM 未设置文件版本

    我有一个 il 文件 可以毫无问题地编译它 我可以很清楚地命名它 所以没有任何问题 但我无法按照我的预期通过属性设置文件版本 使用 ilasm 时如何设置程序集的文件版本 如果我进行往返 我总是会得到一个 res 文件 该文件仅包含不可读的
  • 在 Celery 链中使用分组结果

    我陷入了相对复杂的芹菜链配置 试图实现以下目标 假设有如下一系列任务 chain1 chain DownloadFile s http someserver file gz downloads file returns temp file
  • Angular 4 - 如何显示继承类组件的模板

    我正在尝试根据项目 组件 的类型显示项目列表 我有一系列组件 全部继承自基类 数组类型被定义为基类的类型 我想显示数组 比如说作为项目列表 每个数组都有自己的模板 而不是基本模板 我已经尝试过 在app component html中
  • 调用 `this.setState()` 会中断对 componentWillReceiveProps 中 prop 的流类型检查

    当我调用时 我在知道是字符串的 prop 上遇到流错误this setState 就在它之前 如果我移动setState 在使用 prop 的行之后调用 错误就会消失 我收到的错误是 null 此类型与预期的字符串参数类型不兼容 不明确的
  • 在 contenteditable 元素中,在 HTML 标签之间移动光标

    http jsfiddle net Y7tgx 2 http jsfiddle net Y7tgx 2 Firefox 比 Chrome 处理得更好 但都不完全是我想要的方式 它们都将所有相邻的 HTML 标签集中在一起并将它们视为一个 我
  • 在 C++11 中禁用复制类的最简洁方法

    当存在用户定义的析构函数时 我在处理自 C 11 默认生成的复制构造函数和复制赋值运算符以来已弃用的问题 对于大多数足够简单的类 默认生成的构造函数 运算符和析构函数都可以 考虑以下声明析构函数的原因 在基类中将普通析构函数设为虚拟 hea
  • Ember 的 registerBoundHelper 和车把块

    所以我已经从here https github com danharper Handlebars Helpers并修改它 以便它使用 registerBoundHelper 通过 Ember 注册它的助手 我这样做的原因是因为我基本上需要一
  • *** _pickle.UnpicklingError:pickle 数据被截断

    我有一个包含一千个 pickle 文件的目录 我将它们一一加载 如下所示 我正在使用 python3 import pickle for data in directory with open data rb as handle pickl
  • 过度使用泛型[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 当没有明显的实际好处时 人们对泛型的
  • 从 CSV 中删除行

    我有一个包含多个标题的 csv 文档 例如 Date RQ PM SME Activity Status code 2 2 12 6886 D WV John Smith Recent 2004 以及一个文本文档 它只是状态代码的列表 每行
  • 将对象传递给网络工作者

    我正在尝试通过 postMessage 函数将对象传递给网络工作者 这个对象是一个正方形 有几个功能可以在画布上和其他东西上绘制自己 Web Worker 必须返回此对象的数组 问题是 当我使用该对象调用 postMessage 函数时 出
  • 将 SendGrid 与 appharbor 一起使用时出现问题

    我正在使用 appharbor 添加 SendGrid 作为插件 他们为我提供了 smtp 主机 smtp sendgrid com 端口 587 用户 32adf793 2cbf 492c 9bb9 apphb com 当我使用这些详细信
  • Collectors#toList 的运行时复杂性

    在Java库源代码中 Collectors toList方法定义如下 public static
  • Mailchimp 注册表单与 angular2 [重复]

    这个问题在这里已经有答案了 我正在尝试将 mailchimp 注册表单嵌入到我的 angular2 应用程序中 http kb mailchimp com lists signup forms add a signup form to yo
  • 取消部署出现错误:应用程序未注册(Glassfish)

    我使用 Glassfish 的 Web GUI 取消部署了我的应用程序 但是 如果我按取消部署我的实际应用程序 则什么也不会发生 在我的日志文件中出现新错误 应用程序未注册 严重 我现在如何正确取消部署我的APP 我找到了一个简单的答案 从
  • SWT:单显示器与多显示器

    SWT 旨在支持多种Display实例 每个实例都有自己的事件循环 这对于什么目的是有用的或需要的 不是一个Display实例 例如Display getDefault 充足的 Display 类的文档说 使用 SWT 构建的应用程序几乎总
  • 无法在 Safari 或 UIWebView 中通过 HTTPS 查看 Quicktime 影片

    我试图让我的 iPhone 应用程序除了 HTTP 之外还可以使用 HTTPS 但使用 UIWebView 或 MPMoviePlayerController 查看 Quicktime MOV 文件似乎无法通过 HTTPS 工作 我得到 这
  • Android 中 OOM(内存不足异常)是如何发生的?

    我正在尝试显示来自画廊的图像或从相机捕获的图像ImageView 我开始得到OOM正在处理中 所以我决定找出它是如何工作的 所以我尝试使用不同尺寸的图像 这是观察结果 我尝试将 19KB 的图像加载到ImageView并收到以下错误消息 无
  • Flutter:在“bottomNavigationBar”上显示“showBottomSheet”

    我如何在 bottomNavigationBar 顶部显示 showBottomSheet 说明性示例 当用户单击图钉 屏幕1 时 结果是 屏幕2 但我想得到 屏幕3 return Scaffold appBar AppBar title
  • 在 PyTorch 中使用 module.to() 移动成员张量

    我正在 PyTorch 中构建变分自动编码器 VAE 但在编写与设备无关的代码时遇到问题 自动编码器是nn Module具有编码器和解码器网络 它们也是 网络的所有权重都可以通过调用从一台设备移动到另一台设备net to device 我遇