Truncated backpropagation in PyTorch (code check)


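I am trying to implement truncated backpropagation through time (TBPTT) in PyTorch for the simple case where K1 = K2. The implementation below produces reasonable output, but I want to make sure it is correct. The PyTorch examples of TBPTT I found online are inconsistent about detaching the hidden state, zeroing the gradients, and the order of these operations. If I have made a mistake, please let me know.

In the code below, H holds the current hidden state, and model(weights, H, x) returns the prediction and the new hidden state.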

while i < NUM_STEPS:
    # Grab x, y for ith datapoint
    x = data[i]
    target = true_output[i]

    # Run model
    output, new_hidden = model(weights, H, x)
    H = new_hidden

    # Update running error
    error += (output - target)**2

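    if (i+1) % K == 0:
        # Backpropagate through the last K steps, then update the weights
        error.backward()
        opt.step()       # opt: optimizer over weights (assumed to be defined earlier)
        opt.zero_grad()
        error = 0
        H = H.detach()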

    i += 1

So the idea of your code is to detach the last hidden state after every K-th step. Yes, your implementation is correct, and this answer confirms it.

# truncated to the last K timesteps
while i < NUM_STEPS:
    out = model(out)
    if (i+1) % K == 0:
        out.backward()
        out = out.detach()
    i += 1
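For reference, here is a minimal runnable sketch of the same K1 = K2 loop. It is only an illustration under assumptions: `nn.RNNCell` plus a linear `readout` stand in for model(weights, H, x), the data is random, and names such as `cell`, `readout`, `opt` and `num_updates` are made up for this example.

```python
import torch
from torch import nn

torch.manual_seed(0)

NUM_STEPS, K = 12, 4  # K1 = K2 = K

# Hypothetical stand-ins for model(weights, H, x): a recurrent cell plus a readout.
cell = nn.RNNCell(input_size=3, hidden_size=5)
readout = nn.Linear(5, 1)
params = list(cell.parameters()) + list(readout.parameters())
opt = torch.optim.SGD(params, lr=0.01)

# Random toy data, shaped (time, batch, features).
data = torch.randn(NUM_STEPS, 1, 3)
true_output = torch.randn(NUM_STEPS, 1, 1)

H = torch.zeros(1, 5)  # initial hidden state
error = 0.0
num_updates = 0
opt.zero_grad()

for i in range(NUM_STEPS):
    # Run the model for one time step
    H = cell(data[i], H)
    output = readout(H)

    # Update running error over the current chunk
    error = error + ((output - true_output[i]) ** 2).sum()

    if (i + 1) % K == 0:
        error.backward()   # gradients flow through the last K steps only
        opt.step()
        opt.zero_grad()
        error = 0.0
        H = H.detach()     # cut the graph so the next chunk starts fresh
        num_updates += 1
```

Because H is detached at every chunk boundary, each call to backward() only walks the graph built since the previous boundary, so retain_graph is not needed.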


import torch

from ignite.engine import Engine, EventEnum, _prepare_batch
from ignite.utils import apply_to_tensor

class Tbptt_Events(EventEnum):
    """Additional tbptt events.

    Additional events for truncated backpropagation through time dedicated
    trainer.
    """

    TIME_ITERATION_STARTED = "time_iteration_started"
    TIME_ITERATION_COMPLETED = "time_iteration_completed"

def _detach_hidden(hidden):
    """Cut backpropagation graph.

    Auxiliary function to cut the backpropagation graph by detaching the hidden
    vector.
    """
    return apply_to_tensor(hidden, torch.Tensor.detach)

def create_supervised_tbptt_trainer(
    model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=_prepare_batch
):
    """Create a trainer for truncated backprop through time supervised models.

    Training a recurrent model on long sequences is computationally intensive, as
    it requires processing the whole sequence before getting a gradient.
    However, when the training loss is computed over many outputs
    ("X to many"), there is an opportunity to compute a gradient over a
    subsequence. This is known as truncated backpropagation through time.
    This supervised trainer applies a gradient optimization step every `tbtt_step`
    time steps of the sequence, while backpropagating through the same
    `tbtt_step` time steps.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        tbtt_step (int): the length of time chunks (last one may be smaller).
        dim (int): axis representing the time dimension.
        device (str, optional): device type specification (default: None).
            Applies to batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU,
            the copy may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`,
            `non_blocking` and outputs tuple of tensors `(batch_x, batch_y)`.

    .. warning::

        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.

        For more information see:

        * PyTorch Documentation
        * PyTorch's Explanation

    Returns:
        Engine: a trainer engine with supervised update function.
    """


    def _update(engine, batch):
        loss_list = []
        hidden = None

        x, y = batch
        for batch_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)):
            x_t, y_t = prepare_batch(batch_t, device=device, non_blocking=non_blocking)
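            # Fire event for start of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_STARTED)

            # Forward, backward and optimizer step
            model.train()
            optimizer.zero_grad()
            if hidden is None:
                y_pred_t, hidden = model(x_t)
            else:
                # Detach so gradients do not flow into the previous chunk
                hidden = _detach_hidden(hidden)
                y_pred_t, hidden = model(x_t, hidden)
            loss_t = loss_fn(y_pred_t, y_t)
            loss_t.backward()
            optimizer.step()

            # Setting state of engine for consistent behaviour
            engine.state.output = loss_t.item()
            loss_list.append(loss_t.item())

            # Fire event for end of iteration
            engine.fire_event(Tbptt_Events.TIME_ITERATION_COMPLETED)

        # return average loss over the time splits
        return sum(loss_list) / len(loss_list)

    engine = Engine(_update)
    engine.register_events(*Tbptt_Events)
    return engine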
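The chunked update above can also be sketched in plain PyTorch, without ignite. This is only an illustration, not the library's code: `TinyLSTM`, `detach_hidden`, the tensor shapes and the random data are assumptions made for the example. It shows how `Tensor.split` cuts the time axis into pieces of length `tbtt_step` (the last one may be shorter) and how an LSTM's `(h, c)` tuple is detached between pieces.

```python
import torch
from torch import nn

torch.manual_seed(0)


class TinyLSTM(nn.Module):
    """Hypothetical model returning (predictions, hidden), as the trainer expects."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=8)  # input: (time, batch, features)
        self.head = nn.Linear(8, 1)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)
        return self.head(out), hidden


def detach_hidden(hidden):
    """Detach a tensor or a (possibly nested) tuple of tensors from the graph."""
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    return tuple(detach_hidden(h) for h in hidden)


model = TinyLSTM()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

x = torch.randn(10, 2, 3)  # 10 time steps, batch of 2
y = torch.randn(10, 2, 1)
tbtt_step, dim = 4, 0      # time is the first axis here

hidden, losses, chunk_lengths = None, [], []
for x_t, y_t in zip(x.split(tbtt_step, dim=dim), y.split(tbtt_step, dim=dim)):
    optimizer.zero_grad()
    if hidden is not None:
        # Keep the hidden values but drop their history
        hidden = detach_hidden(hidden)
    y_pred_t, hidden = model(x_t, hidden)
    loss_t = loss_fn(y_pred_t, y_t)
    loss_t.backward()
    optimizer.step()
    losses.append(loss_t.item())
    chunk_lengths.append(x_t.shape[dim])

# Average loss over the time splits
average_loss = sum(losses) / len(losses)
```

Note that the gradients are zeroed before each chunk's forward pass, so every optimizer step uses only the gradient of that chunk.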

