Neural Ordinary Differential Equation 神经常微分方程(Neural ODEs)

2023-11-06

  用微分方程的视角来看待和理解神经网络是一种新的视角,该观点最早出现在2016年鄂维南院士的一篇proposal里:A Proposal on Machine Learning via Dynamical Systems.

Motivation

The core idea is that certain types of neural networks are analogous to a discretized differential equation, so maybe using off-the-shelf differential equation solvers will help get better results.

主要思想是:特定类型的神经网络可以看作离散的微分方程,所以使用现成的微分方程求解器可以帮助获得更好的结果。

First to see the contribution described in the original paper: “We introduce a new family of deep neural network models. Instead of specifying a discrete sequence of hidden layers, we parameterize the derivative of the hidden state using a neural network.”

先来看看原文中怎样描述这个贡献: “我们提出了一族新的神经网络模型…”。

不是指定一个离散序列,我们参数化了网络隐藏状态的导数。

Why we should to parameterize the derivative of the hidden state of the neural network? The answer is we should capture the characteristic of the middle layer of the neural network. Here, the derivative of the hidden layer is equal to the gradient in the backpropagate progress.

为什么参数化网络隐藏状态的导数,也就是中间层的导数,因为要建立隐藏状态的微分方程。中间层的导数不就是网络的梯度吗?

如果直接将中间层的结果求解出来,是否时避免了反向传播过程?


Reverse-mode automatic differentiation of ODE solutions

反向模式的自动微分ODE的解决方案

Let’s we show the result of the forward progress of neural network.
我们先来看NN(Neural Network)的前向过程:在这里插入图片描述
z ( t 1 ) z(t_1) z(t1) 代表 t 1 t_1 t1 时刻的隐藏状态(hidden state),而当隐藏状态被连续化后, t 0 t_0 t0 t 1 t_1 t1 时刻的中间隐藏状态的和就是等式中间部分的积分项。而整个前向过程可以用 ODE 求解器进行求解。注意,这里并没有定义 f f f 的具体形式,一个需要考虑的问题是:ODE solver 是否可以求解任意形式的 f f f。//todo

“The main technical difficulty in training continuous-depth networks is performing reverse-mode differentiation (also known as backpropagation) through the ODE solver.”

难点是使用 ODE solver 对连续的网络求解其反向模型的微分形式

We treat the ODE solver as a black box, and compute gradients using the adjoint sensitivity method. This approach computes gradients by solving a second, augmented ODE backwards in time, and is applicable to all ODE solvers. "

这里,将 ODE solver 看作是一个黑盒子,使用伴随敏感方法来求解梯度。该方法通过求解第二个、增强了的时间向后(时间轴反向)的 ODE 来计算梯度,而且所有 ODE solvers 都适用。具体过程为:

To optimize L L L, we require gradients with respect to θ \theta θ. The first step is to determining how the gradient of the loss depends on the hidden state z ( t ) z(t) z(t) at each instant. This quantity is called the adjoint a ( t ) = ∂ L ∂ z ( t ) a(t) =\frac{\partial L}{\partial z(t)} a(t)=z(t)L. Its dynamics are given by another ODE, which can be thought of as the instantaneous analog of the chain rule:

为了优化损失 L, 需要计算它对 θ \theta θ 的导数。第一步是怎样确定梯度依赖的隐层状态 z ( t ) z(t) z(t). 该性质称为 伴随。它的动态过程被另一个 ODE 来求解,可以把这种瞬时性被看作链式法则:
在这里插入图片描述(1)
该等式在1962年由 Pontryagin et al. 的论文《The mathematical theory of optimal processes》给出过证明,不过,本文作者也给出了相应的更简洁的证明过程:
  对于连续的隐层状态,可以将在时间上变化后的 ε \varepsilon ε 记作:
在这里插入图片描述(2)
上述公式说明,下一个状态 z z z 是关于上一个状态的函数(这里将参数 θ \theta θ 看作常量,具体的积分值由 f f f 决定)。 因此,相应的链式法则可以记作:
在这里插入图片描述(3)
由此,可以证明(1)式:
在这里插入图片描述
通过上述证明过程(引入 T ε ( z ( t ) ) T_{\varepsilon}(z(t)) Tε(z(t)) ,以说明 z ( t + ε ) z(t+\varepsilon) z(t+ε) z ( t ) z(t) z(t)的函数),第二步用到等式(3),另外对等式(2)进行泰勒展开( T ε T_{\varepsilon} Tε 中的 t t t 被隐含了),注意展开过程中的无穷小参数同样取 ε \varepsilon ε,然后就可以得到等式(1)。

  We specify the constraint on the last time point, which is simply the gradient of the loss wrt the last time point, and can obtain the gradients with respect to the hidden state at any time, including the initial value.

  这里就可以看出 ODE 沿时间的反向过程和 NN 中反向传播(BP)的相似性了。也就是通过 ODE 系统,前向和后向都是可以计算的。这里假设(限制)最后时刻( T N T_N TN)的隐层状态是已知的(可以直接通过 loss 的梯度获取),就可以求解任意时刻的隐层状态了(包括初始时刻):
在这里插入图片描述
  由此,整个 ODE 的反向过程的理论部分证明完成。

  这里引入了一个伴随状态(Adjoint State),它和前向状态相反,通过另一个 ODE 来求解。 关键是它们是怎样建立联系的?见下图:

在这里插入图片描述
  The adjoint sensitivity method solves an augmented ODE backwards in time. The augmented system contains both the original state and the sensitivity of the loss with respect to the state.
  伴随敏感度方法使用一个增强的在时间上反向的 ODE。该增强系统同时包括 原来的状态 a ( t ) a(t) a(t) 和损失对该状态的敏感度 ∂ L a ( t ) ∂ z ( t N ) \frac{\partial La(t)}{\partial {z(t_N)}} z(tN)La(t)。具体它俩是怎么计算的?
  答案是:由损失敏感度 ∂ L a ( t ) ∂ z ( t N ) \frac{\partial La(t)}{\partial {z(t_N)}} z(tN)La(t) 调节伴随(adjoint)状态 a ( t ) a(t) a(t), 然后再有伴随状态 a ( t ) a(t) a(t) 得到损失敏感度 ∂ L a ( t ) ∂ z ( t N ) \frac{\partial La(t)}{\partial {z(t_N)}} z(tN)La(t) 。这是 ODE 反向的链式过程。至此,整个反向传播的过程就被模拟了!

  Computing the gradients with respect to the parameters θ requires evaluating a third integral, which depends on both z(t) and a(t):
  计算关于 θ \theta θ 的梯度,还要计算相关变量 z(t) and a(t) 的积分:
在这里插入图片描述(4)

  通过等式(1)和(4)就可以计算出梯度了, a ( t ) T ∂ f ∂ z {a(t)}^T \frac{\partial f}{\partial z} a(t)Tzf a ( t ) T ∂ f ∂ θ {a(t)}^T \frac{\partial f}{\partial \theta} a(t)Tθf 的vector-Jacobian products 都可以通过 ODE solver 快速求解。 所有的积分解: z , a , ∂ L ∂ θ z, a, \frac{\partial L}{\partial \theta} z,a,θL 都可以通过一个 ODE solver 来求解,可以将它们组合成一个向量解 (增强的状态,augmented state)。具体步骤见算法 1:
在这里插入图片描述
该算法基本上是上述过程的综合。首先定义初始状态 s 0 s_0 s0,然后定义 增强状态,aug_dynamics,该状态包括 f ( z ( t ) , t , θ ) f(z(t),t,\theta) f(z(t),t,θ) a ( t ) T ∂ f ∂ z {a(t)}^T \frac{\partial f}{\partial z} a(t)Tzf a ( t ) T ∂ f ∂ θ {a(t)}^T \frac{\partial f}{\partial \theta} a(t)Tθf 的vector-Jacobian products(通过自动微分工具得到)。然后通过 ODE solver 求解前一时刻的隐层状态,敏感状态,和梯度。注意,这些都是合并起来的向量形式(算子形式的张量?)。最后,返回敏感状态(用以下一时刻计算敏感状态)和梯度(用以更新参数 θ \theta θ)。


Replacing residual networks with ODEs

将ResNets 换成 ODEs

Software: To solve ODE initial value problems numerically, we use the implicit Adams method implemented in LSODE and VODE and interfaced through the scipy.integrate package. Being an implicit method, it has better guarantees than explicit methods such as Runge-Kutta but requires solving a nonlinear optimization problem at every step.This setup makes direct backpropagation through the integrator difficult.

软件实现: 为了求解 ODE 的数值解, 作者使用 Adams (一种梯度优化方法)方法实现了 LSODE 和 VODE 的scipy.integrate 接口。 作为一种隐式方法,它比显式方法有较好的保证,如 Runge-Kutta 需要在每一步求解非线性优化问题。这种设置使得直接使用积分器求解反向传播是困难的。作者使用 Python 的自动微分方法实现了伴随敏感方法,并使用 Tensorflow 在GPU上实现了 隐层状态的动态和求导(从Fortran ODE Solver 调用,从 Python autograd 中调用)。

Model Architectures: We experiment with a small residual network which downsamples the input twice then applies 6 standard residual blocks He et al. (2016b), which are replaced by an ODESolve module in the ODE-Net variant. We also test a network with the same architecture but where gradients are backpropagated directly through a Runge-Kutta integrator, referred to as RK-Net.

  论文中用两个降采样和6个残差块的小型 ResNet 进行了实验,将残差块替换为ODESolve 模块就变成了 ODE-Net 变体。作者还使用相同的架构测试了使用 Runge-Kutta 积分器来反向传播梯度的 RK-Net。

  代码实现

首先看整体网络结构:

feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]

其中,ODEfunc 定义为:

class ODEfunc(nn.Module):

    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0 # number of forward ?

    def forward(self, t, x):
        self.nfe += 1
        out = self.norm1(x)
        out = self.relu(out)
        out = self.conv1(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm3(out)
        return out

  与 Residual Block 不同的是多加了一次 Batch Normalization,ODEfunc 中的卷积 ConcatConv2d 实现为:

class ConcatConv2d(nn.Module):

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t   # extract the first channel and multiply the time t.
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)

  可以看到 ConcatConv2d 和原来的 卷积方式 基本相同,只是在 前向过程中,添加了变量(variable) t t t , 其中,torch.ones_like 返回一个填充了标量值1的张量,其大小与之相同 input ,乘以 t t t 表示在 t t t 时刻。然后,将 t t t x x x 合并(concatenation)起来,然后作为卷积的输入。这里有个问题,为什么变量 t t t 的 size 是 feature size,难道是对每个feature position 做连续化?//TODO (这里 grad 的形状和feature size 的形状相同)。

  接下来就是 ODEBlock的定义:

class ODEBlock(nn.Module):

    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol) # ODE forward
        return out[1]

  ODEBlock 中定义了 积分时间(integration time) t ∈ [ 0 , 1 ] t \in [0,1] t[0,1] ,然后在前向过程中传入 odeint 中,关键点是 odeint, 按上述算法中。这里 rtol 和 atol 是 容忍度(tolerance),即模型的精度设定。out[1] 是梯度(gradient) ∂ L ∂ θ \frac{\partial L}{\partial \theta} θL。这样我们求得了梯度。其中,odeint 的实现为:

def odeint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None):
      tensor_input, func, y0, t = _check_inputs(func, y0, t)

    if options is None:
        options = {}
    elif method is None:
        raise ValueError('cannot supply `options` without specifying `method`')

    if method is None:
        method = 'dopri5'

    solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, **options)
    solution = solver.integrate(t)

    if tensor_input:
        solution = solution[0]
    return solution

The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition. Solves the initial value problem (IVP) for a non-stiff system of first order ODEs: ∂ y ∂ t = f ( t , y ) \frac{\partial y}{\partial t}=f(t,y) ty=f(t,y) s.t. y ( t 0 ) = y 0 y(t_0)=y_0 y(t0)=y0 where y is a Tensor of any shape.

odeint 解的是非复杂(non-stiff)系统的一阶 ODE 的初值问题 (IVP),其中,y是任意形状的张量。以下是其中参数的解释:

"""
	Args:
        func: Function that maps a Tensor holding the state `y` and a scalar Tensor
            `t` into a Tensor of state derivatives with respect to time.
    func:把一个含有状态张量 y 和常张量 t 映射到 一个关于时间可导的张量上。
        y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
            have any floating point or complex dtype.
    y0: NxD维度的张量,是 y 在 t[0] 的初始点,可以是任意复杂的类型。
        t: 1-D Tensor holding a sequence of time points for which to solve for
            `y`. The initial time point should be the first element of this sequence,
            and each time must be larger than the previous time. May have any floating
            point dtype. Converted to a Tensor with float64 dtype.
    t: 1xD的张量,表示一系列用于求解 y 的时间点。
        rtol: optional float64 Tensor specifying an upper bound on relative error,
            per element of `y`.
    rtol: 相对错误容忍度,以限制张量 y 中每个元素的上限值。(可调节)
        atol: optional float64 Tensor specifying an upper bound on absolute error,
            per element of `y`.
    atol: 绝对错误容忍度,以限制张量 y 中每个元素的上限值。(可调节)
        method: optional string indicating the integration method to use.
        method: 可选的string型 以决定那种 积分方法 被使用。
        options: optional dict of configuring options for the indicated integration
            method. Can only be provided if a `method` is explicitly set.
    options: 可选的字典类型,用于配置积分方法。
        name: Optional name for this operation.
	name:  为该操作指定名称。
    Returns:
        y: Tensor, where the first dimension corresponds to different
            time points. Contains the solved value of y for each desired time point in
            `t`, with the initial value `y0` being the first element along the first
            dimension.
    Returns: 返回第一个维度对应不同的时间点的 y 张量。
             包含 y 在每个时间点 t 上被期望的解。(所有时间点的解都被求得了),
             初始值 y0 是第一维度的第一个元素。
"""

  看一下 SOLOVE中的积分方法:

SOLVERS = {
    'explicit_adams': AdamsBashforth,
    'fixed_adams': AdamsBashforthMoulton,
    'adams': VariableCoefficientAdamsBashforth,
    'tsit5': Tsit5Solver,
    'dopri5': Dopri5Solver,
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
}

  这里牵涉到微分方程的数值解法。这里 AdamsBashforth、AdamsBashforthMoulton、Euler、Midpoint、RK4 (Fourth-order Runge-Kutta with 3/8 rule) 属于 FixedGridODESolver (固定网格 ODE 求解器),其中,前两个 Adams 类型的求解器 是作者自己实现的 Adam梯度下降方法来求解的 FixedGridODESolver。而VariableCoefficientAdamsBashforth、Tsit5Solver ()、Dopri5Solver (Runge-Kutta 4(5))属于 AdaptiveStepsizeODESolver(自定义步长的 ODE 求解器)。论文中把 ODE solver 当作一个黑盒子(black box),我们知道它可以求解我们所需要的微分方程。这里只看最简单的 Euler 求解器:

class Euler(FixedGridODESolver):

    def step_func(self, func, t, dt, y):
        return tuple(dt * f_ for f_ in func(t, y))

  它只是实现了父类 FixedGridODESolver 中的 step_func,父类 FixedGridODESolver 的实现为:

class FixedGridODESolver(object):
	def __init__(self, func, y0, step_size=None, grid_constructor=None, **unused_kwargs):
		...
		... # here, I omit some initialize progress in origin code
		# and omit some grid constructor progress.
		 
	@abc.abstractmethod
    def step_func(self, func, t, dt, y):
        pass

    def integrate(self, t):
        _assert_increasing(t) # t is increase sequence
        t = t.type_as(self.y0[0])
        time_grid = self.grid_constructor(self.func, self.y0, t) # grad
        assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
        time_grid = time_grid.to(self.y0[0])

        solution = [self.y0] # target solution list

        j = 1
        y0 = self.y0
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0) # use step function
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy)) # y1=y0+dy
            y0 = y1 # why to this?
			# linear interpolate the time sequence.
            while j < len(t) and t1 >= t[j]:
                solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
                j += 1

        return tuple(map(torch.stack, tuple(zip(*solution))))
        
    def _linear_interp(self, t0, t1, y0, y1, t):
        if t == t0:
            return y0
        if t == t1:
            return y1
        t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0])
        slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1))
        return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))

  这里的积分 应该是对 差分 的积分,即根据初始值 y 0 y_0 y0 和时间序列 t t t 来求 y t y_t yt。 首先构建 time grad,然后使用step_func,根据 func (NN 中的 f f f) 和 time grad 中的 t 以及 y 0 y_0 y0 来计算 d y dy dy, 接着,根据 y 1 = y 0 + d y y_1=y_0+dy y1=y0+dy 求得 y 1 y_1 y1, 这里有一行 y 0 = y 1 y_0=y_1 y0=y1, 为什么把y1赋值给 y 0 y_0 y0 ? 然后再根据 y 0 y_0 y0, y 1 y_1 y1 求插值 ?这样元素不就等于零了? //todo

  到这里,整个 ODE-Net的方法和实现都走一遍了,但我们好像只看到了前向过程?没有反向过程?这是因为 反向过程被 Pytorch 在内部自动实现了 (autograd backpropagate),并没有使用作者提出的 adjoint sensitivity method。作者指出使用 adjoint 方法可将 内存复杂度 降为 O ( 1 ) O(1) O(1)

  Backpropagation through odeint goes through the internals of the solver, but this is not supported for all solvers. Instead, we encourage the use of the adjoint method, which will allow solving with as many steps as necessary due to O(1) memory usage.

  odeint_adjoint simply wraps around odeint, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call. The biggest gotcha is that func must be a nn.Module when using the adjoint method. This is used to collect parameters of the differential equation.

odeint_adjoint 简单第封装了 odeint,并实现了反向过程。但其最大的缺憾(硬伤)是func f f f 的取值必须是 nn.Module 的方法,这是为了收集微分方程的参数。( Why must be collect parameters of the differential equation? The answer is use to backward of adjoint odeint.)看一下adjoint odeint 的实现过程:

def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None):

    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    tensor_input = False
    if torch.is_tensor(y0):

        class TupleFunc(nn.Module):

            def __init__(self, base_func):
                super(TupleFunc, self).__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        tensor_input = True
        y0 = (y0,)
        func = TupleFunc(func)

    flat_params = _flatten(func.parameters())
    ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options)

    if tensor_input:
        ys = ys[0]
    return ys

  首先说明了odeint_adjoint 的变量是有序的,然后通过内部类封装了一下 func,这里明确的限制了 func 是 nn.Module,这样 ODE-Net 的前向过程就实现了。接下来,通过 OdeintAdjointMethod 具体执行 ODE 的前向和反向过程:

class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *args):
        assert len(args) >= 8, 'Internal error: all arguments required.'
        y0, func, t, flat_params, rtol, atol, method, options = \
            args[:-7], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]

        ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options

        with torch.no_grad():
            ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
        ctx.save_for_backward(t, flat_params, *ans)
        return ans

  前向过程很简单,通过继承 torch.autograd.Function,将一些参数赋值给 ctx(没有通过 self 实现,因为ctx只在forward过程中存在。通过 self 会不会更直观),并保存了 t t t,func 的参数 和 odeint 的前向结果,以便在反向过程中使用。再看其反向过程:

    @staticmethod
    def backward(ctx, *grad_output):

        t, flat_params, *ans = ctx.saved_tensors
        ans = tuple(ans)
        func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
        n_tensors = len(ans)
        f_params = tuple(func.parameters())

        # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with torch.set_grad_enabled(True):
                t = t.to(y[0].device).detach().requires_grad_(True)
                y = tuple(y_.detach().requires_grad_(True) for y_ in y)
                func_eval = func(t, y)
                vjp_t, *vjp_y_and_params = torch.autograd.grad(
                    func_eval, (t,) + y + f_params,
                    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
                )
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # autograd.grad returns None if no gradient, set to zero.
            vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
            vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if len(f_params) == 0:
                vjp_params = torch.tensor(0.).to(vjp_y[0])
            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        with torch.no_grad():
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
            adj_params = torch.zeros_like(flat_params)
            adj_time = torch.tensor(0.).to(t)
            time_vjps = []
            for i in range(T - 1, 0, -1):

                ans_i = tuple(ans_[i] for ans_ in ans)
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
                func_i = func(t[i], ans_i)

                # Compute the effect of moving the current time measurement point.
                dLd_cur_t = sum(
                    torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
                    for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
                )
                adj_time = adj_time - dLd_cur_t
                time_vjps.append(dLd_cur_t)

                # Run the augmented system backwards in time.
                if adj_params.numel() == 0:
                    adj_params = torch.tensor(0.).to(adj_y[0])
                aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
                aug_ans = odeint(
                    augmented_dynamics, aug_y0,
                    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
                )

                # Unpack aug_ans.
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_time = aug_ans[2 * n_tensors]
                adj_params = aug_ans[2 * n_tensors + 1]

                adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
                if len(adj_time) > 0: adj_time = adj_time[1]
                if len(adj_params) > 0: adj_params = adj_params[1]

                adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

                del aug_y0, aug_ans

            time_vjps.append(adj_time)
            time_vjps = torch.cat(time_vjps[::-1])

            return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None)

  其中,torch.autograd.grad(outputs, inputs, grad_outputs=None, … ) 是用来计算输出对输入的梯度(Computes and returns the sum of gradients of outputs w.r.t. the inputs.)。这里需要用到 自动微分 中的知识。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Neural Ordinary Differential Equation 神经常微分方程(Neural ODEs) 的相关文章

  • JS实现一键回到顶部的功能(兼容所有浏览器,超级详细)

    我们在浏览网页的时候 大部分都有一个一键回到顶部的按钮 无论是pc端还是移动端 这个功能都很常见 我在一次面试的时候 也要求手写这个功能 首先我们新建一个空页面 把body的高度设置为3000px 这样做的目的是让浏览器出现滚动条 不然我们

随机推荐

  • 动态数组的实现

    public class MyArrayList
  • 栈与队列 数据结构 C语言

    目录 一 栈 1 类型定义 2 接口函数 3 功能实现 初始化栈 进栈 删除栈顶 出栈 销毁栈 其他功能 一 栈 先进后出 后进先出 1 类型定义 typedef int STDataType typedef struct Stack ST
  • 如何用Python进行大数据挖掘和分析?快速入门路径图!

    大数据无处不在 在时下这个年代 不管你喜欢与否 在运营一个成功的商业的过程中都有可能会遇到它 什么是 大数据 大数据就像它看起来那样 有大量的数据 单独而言 你能从单一的数据获取的洞见穷其有限 但是结合复杂数学模型以及强大计算能力的TB级数
  • main函数参数int main(int argc, char *argv[])解析

    main函数可以不带参数 也可以带参数 这个参数可以认为是 main函数的形式参数 C语言规定main函数的参数只能有两个 习惯上这两个参数写为argc和argv 所以C99标准中规定只有以下两种定义方式是正确的 int main void
  • Redis使用Zset做一个排行榜,当权值一样时,怎么按时间排序

    前言 zset是根据score进行排序 当score相同时 默认按照member的字典序进行排序 案例说明 127 0 0 1 6379 gt zadd t1 2 c 1 b 2 a integer 3 127 0 0 1 6379 gt
  • 关于pytorch的backward()

    pytorch中的loss backward 是梯度反传 计算每一个变量的grad 只是之前在纠结GAN的两个loss什么时候反传 参数什么时候更新的时候 观察到backward 后 内存的存储量下降 原来反传完毕之后 就把中间计算变量都释
  • VS code配置C语言,详细教程,初学者专用(附需要的插件)(win系统)

    vscode配置C语言首先下载vscode 这里我就不多说了 我们自己在使用vscode配置c语言后发现c语言根本就不能运行 是因为我们缺少一个配置c语言的插件需要我们自己下载 因为vscode不提供 这里是插件的链接 需要大家自己去提取
  • Windows10自带远程桌面连接Linux--CentOS的操作系统

    加粗样式看到网上好多的都是关于ubuntu类型的连接 或者就是自己在下个软件去连接Linux 而并非是用自带的 本文将为大家自动rdp去连接 1 默认库不包含xrdp 需要安装epel库 yum install epel release y
  • C++ 子类继承父类纯虚函数、虚函数和普通函数的区别

    C 三大特性 封装 继承 多态 今天给大家好好说说继承的奥妙 1 虚函数 C 的虚函数主要作用是 运行时多态 父类中提供虚函数的实现 为子类提供默认的函数实现 子类可以重写父类的虚函数实现子类的特殊化 2 纯虚函数 C 中包含纯虚函数的类
  • Class Not Found-Maven工程单元测试类报错

    很显然 Class Not Found已经说明了相关的类class文明招不到 这一点可以通过打开target目录的classes文件夹得到印证 该问题原因在在于Maven工程不会自动的为我们给java类进行编译 所以就导致了有时候我们jav
  • rocketMQ记录

    https segmentfault com a 1190000017841402 停止命令 sh bin mqshutdown namesrv sh bin mqshutdown broker
  • excel 计算 分位值

    XLFN QUARTILE EXC Result 1 G G 2 和 PERCENTILE 都可以用来计算一组数据的分位数 但是它们的计算方式略有不同 XLFN QUARTILE EXC Result 1 G G 2 是 Excel 中的一
  • AI之路(二)——关于统计学习(statistical learning)Part 1 概论

    从今日起 正式开启AI之路 在人工智能学习领域 无论机器学习还是深度学习 统计学习是入门的最好参考教材 是不可或缺的 因此 这漫漫求索之路 就从统计学习开始吧 我所选择的是李航所著的统计学习 第二版 计划将我对本书的自学总结或心得 能及时地
  • ES6函数新增了哪些扩展?

    目录 一 参数 二 属性 函数的length属性 name属性 三 作用域 四 严格模式 五 箭头函数 一 参数 ES6允许为函数的参数设置默认值 function log x y World console log x y console
  • [开发

    参考资料 安装部署 DataEase 本地源码启动 开发环境搭建 常见问题 编译源码时 dataease plugin 相关依赖无法拉取 DataEase 常见问题及解答 持续更新 社区论坛 下载依赖 使用如下命令下载依赖 mvn depe
  • Linux如何检测到僵尸进城,如何在linux下查看僵尸进程

    首先说说 僵尸进程是什么 僵尸进程是当子进程比父进程先结束 而父进程又没有回收子进程 释放子进程占用的资源 此时子进程将成为一个僵尸进程 如果父进程先退出 子进程被init接管 子进程退出后init会回收其占用的相关资源 我们都知道进程的工
  • BLE蓝牙协议 — 自适应调频算法简单实现

    写在前面 转载文章 若有不妥 通知后我会立即删除 最近看了大神刘权写的 BLE4 0低功耗蓝牙协议总结 感觉收获颇丰 其中有一节是讲解蓝牙的自适应调频算法的 但是代码实现不方便阅读 原文是这样的 小生不才 斗胆做了一下调整 还望大神海涵 下
  • html+css 热茶效果

  • ts获取服务器数据_ionic4中数据交互get post jsonp请求服务器数据

    ionic4 x中请求数据和angular中几乎是一样的 其中get post和和服务器交互使用的是HttpClientModule模块 下面我们看看ionic4中数据交互get post jsonp请求服务器数据 一 ionic4 x g
  • Neural Ordinary Differential Equation 神经常微分方程(Neural ODEs)

    用微分方程的视角来看待和理解神经网络是一种新的视角 该观点最早出现在2016年鄂维南院士的一篇proposal里 A Proposal on Machine Learning via Dynamical Systems Motivation