PyTorch 中的截断反向传播(代码检查)

2024-03-14

我正在尝试在 PyTorch 中实现随时间截断的反向传播,对于以下简单情况K1=K2。我下面有一个实现可以产生合理的输出,但我只是想确保它是正确的。当我在网上查找 TBTT 的 PyTorch 示例时,它们在分离隐藏状态、将梯度归零以及这些操作的顺序方面做了不一致的事情。如果我犯了错误,请告诉我。

在下面的代码中,H保持当前的隐藏状态,并且model(weights, H, x)输出预测和新的隐藏状态。

while i < NUM_STEPS:
    # Grab x, y for ith datapoint
    x = data[i]
    target = true_output[i]

    # Run model
    output, new_hidden = model(weights, H, x)
    H = new_hidden

    # Update running error
    error += (output - target)**2

    if (i+1) % K == 0:
        # Backpropagate
        error.backward()
        opt.step()
        opt.zero_grad()
        error = 0
        H = H.detach()

    i += 1

因此,代码的想法是在每个第 K 步之后隔离最后一个变量。是的,你的实现是绝对正确的,这answer https://discuss.pytorch.org/t/correct-way-to-do-backpropagation-through-time/11701/3证实了这一点。

# truncated to the last K timesteps
while i < NUM_STEPS:
    out = model(out)
    if (i+1) % K == 0:
        out.backward()
        out.detach()
out.backward()

您还可以关注this https://github.com/pytorch/ignite/blob/master/ignite/contrib/engines/tbptt.py示例供您参考。

import torch

from ignite.engine import Engine, EventEnum, _prepare_batch
from ignite.utils import apply_to_tensor


class Tbptt_Events(EventEnum):
    """Aditional tbptt events.

    Additional events for truncated backpropagation throught time dedicated
    trainer.
    """

    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"


def _detach_hidden(hidden):
    """Cut backpropagation graph.

    Auxillary function to cut the backpropagation graph by detaching the hidden
    vector.
    """
    return apply_to_tensor(hidden, torch.Tensor.detach)


def create_supervised_tbptt_trainer(
    model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=_prepare_batch
):
    """Create a trainer for truncated backprop through time supervised models.

    Training recurrent model on long sequences is computationally intensive as
    it requires to process the whole sequence before getting a gradient.
    However, when the training loss is computed over many outputs
    (`X to many <https://karpathy.github.io/2015/05/21/rnn-effectiveness/>`_),
    there is an opportunity to compute a gradient over a subsequence. This is
    known as
    `truncated backpropagation through time <https://machinelearningmastery.com/
    gentle-introduction-backpropagation-time/>`_.
    This supervised trainer apply gradient optimization step every `tbtt_step`
    time steps of the sequence, while backpropagating through the same
    `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        tbtt_step (int): the length of time chunks (last one may be smaller).
        dim (int): axis representing the time dimension.
        device (str, optional): device type specification (default: None).
            Applies to batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.

    .. warning::

        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.

        For more information see:

        * `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        * `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: a trainer engine with supervised update function.

    """

    def _update(engine, batch):
        loss_list = []
        hidden = None

        x, y = batch
        for batch_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)):
            x_t, y_t = prepare_batch(batch_t, device=device, non_blocking=non_blocking)
            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)
            # Forward, backward and
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Setting state of engine for consistent behaviour
            engine.state.output = loss_t.item()
            loss_list.append(loss_t.item())

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch 中的截断反向传播(代码检查) 的相关文章

随机推荐

  • Galaxy Tab 出现奇怪的性能问题

    我正在编写 2d 教程 并且能够在 Samsung Galaxy Tab 上测试我当前的教程部分 本教程只是在屏幕上随机移动默认图标 通过点击 我创建了一个新的移动图标 只要屏幕上有 25 个或更少的元素 Galaxy 上的一切都可以正常运
  • Linux 上 Objective-C 的 IDE [已关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我正在学习 Objective C 我想知道在哪里可以找到 Linux 上 Objective C 的
  • 加载逗号后空格不一致的 CSV 文件

    我想使用加载 CSV 文件LOAD DATA INFILE命令 但逗号后面的空格不一致 即有些逗号后面跟着空格 有些逗号后面没有空格 我尝试使用FIELDS TERMINATED BY 指令 但结果表中的某些字段包含前导空格 如果输入是 a
  • 如何将我自己的存储库分叉到新项目中?

    我正在开发一个 HTML5 游戏引擎 我使用 Git 作为 SV 并使用 GitHub 来实际托管该项目 我在设计上做了一些实质性的改变 主要是切换到实体系统范例 我认为是时候换一个新引擎了 我想将它建立在旧引擎的基础上 因为我可以使用很多
  • Javascript:添加动态方法的更好方法?

    我想知道是否有更好的方法向现有对象添加动态方法 基本上 我试图动态地组装新方法 然后将它们附加到现有函数中 该演示代码有效 builder function fn methods method builder for p in method
  • 加载 JSON 文件时出现内存错误

    当我加载 500Mo 大的 JSON 文件时 Python 和间谍程序 返回 MemoryError 但我的电脑有 32Go RAM 当我尝试加载它时 spyder 显示的 内存 从 15 变为 19 看来我应该有更多的空间 有什么我没想到
  • 将网络抓取的响应保存为 csv 文件

    我从网站下载了一个文件rvest 如何将回复另存为csv file Step 1 猴子补丁rvest像这个线程中的包 如何在 Rvest 包中提交登录表单 不带按钮参数 https stackoverflow com questions 3
  • 如何在silverlight3.0中播放Youtube视频

    我正在开发一个 silverlight 应用程序 我想在其中播放 youtube 视频 任何建议请 可供参考的任何示例或任何链接 提前致谢 这里有一个关于这个问题的有趣主题 其中包含 SL 3 0 beta 中的一些示例 http silv
  • 在简单的 main() 中获取rawinputdata

    我正在尝试使用简单的 C 技术和 Windows 从操纵杆读取值 我的目标是编写一个程序 每当操纵杆信号超过预定义阈值时 该程序就会发送键盘命令 键盘命令将由当时处于活动状态的窗口拾取 我的 C 编码技能有限 因此我希望以最简单的方式完成此
  • 如何将 Tomcat 重写阀添加到 Spring Boot 2.0 应用程序

    我正在尝试在 Spring Boot 应用程序中使用 Tomcat 重写阀 但是无法确定将 rewrite conf 放在哪里才能成功加载 我将 Spring Boot 2 0 3 RELEASE 与 Tomcat 8 5 31 一起使用
  • 以编程方式更改图像分辨率

    我计算过 如果我希望生成的图像为 A4 尺寸 600dpi 用于打印目的 则需要为 7016x4961px 72dpi 所以 我以编程方式生成它 然后在 Photoshop 中测试它 它似乎很好 所以如果我调整它的大小 它会获得正确的大小和
  • 如何让FlatList充满高度?

    import React from react import SafeAreaView KeyboardAvoidingView FlatList View Text TextInput Button StyleSheet from rea
  • 配置 grunt 复制任务以排除文件/文件夹

    我已经安装了 grunt 任务grunt contrib copy 我把它嵌入到我的Gruntfile js并通过加载任务grunt loadNpmTasks grunt contrib copy 目前 我使用以下配置来创建一个包含 js
  • 类在需要新实例的地方保留以前的内容

    我定义了一个类 以及一个创建该类实例的函数 我认为这个函数应该每次都创建一个新实例 然而 它看起来像是 继承 了上次调用的内容 任何人都可以解释一下吗 谢谢 class test a def b self x self a append x
  • iframe shimming 或 ie6(及更低版本)选择 z-index 错误

    嗯 不知道有没有人遇到过这个问题简要说明是关于 IE6 的任何
  • constexpr 和奇怪的错误

    我有 constexpr bool is concurrency selected const return ConcurrentGBx gt isChecked GBx is a groupbox with checkbox 我收到错误
  • Backbone.js 事件处理程序命名的最佳实践

    假设我在视图中有一个函数 当某种状态发生更改时会触发该函数 最好给它起什么名字 为什么 状态改变 状态改变 状态改变时 状态改变时 我个人更喜欢使用onEventName名称保持 DOM 事件处理程序的本机命名约定 Like myEleme
  • 如何将Javascript的window.find限制为特定的DIV?

    是否可以在 Safari Firefox Chrome 中使用 Javascript 在特定的 div 容器中搜索给定的文本字符串 我知道你可以使用window find str 搜索整个页面 但是否可以将搜索区域限制为仅在 div 内 T
  • IntelliJ Idea groovy.lang.GroovyRuntimeException:模块版本冲突

    我的 Maven 构建很好 并且能够从 cli 运行 groovy 但是 如果我尝试在 IntelliJ Idea 版本 15 社区版 中运行我的 groovy 类 则会出现以下错误 Exception in thread main jav
  • PyTorch 中的截断反向传播(代码检查)

    我正在尝试在 PyTorch 中实现随时间截断的反向传播 对于以下简单情况K1 K2 我下面有一个实现可以产生合理的输出 但我只是想确保它是正确的 当我在网上查找 TBTT 的 PyTorch 示例时 它们在分离隐藏状态 将梯度归零以及这些