【摘 要】 我们介绍了一个新的深度神经网络模型家族。在该模型中,我们并没有定义隐藏层的离散序列,而是使用神经网络对隐状态的导数进行了参数化,并使用黑盒微分方程求解器计算神经网络的输出。这些“连续深度” 的模型具有恒定的内存成本,这使其计算策略适应每个输入,并且可以明确地以数值精度换取速度。我们在“连续深度” 的残差网络和“连续时间”的隐变量模型中展示了这些性质。我们还构建了连续的归一化流,这是一种可以通过最大似然进行训练、且无需对数据维度进行分区或排序的生成式模型。对于训练,我们展示了在不访问内部计算的情况下,任意常微分方程求解的反向传播方法,这使大型模型能够对常微分方程进行端到端训练。

【原 文】 Chen, R.T.Q. et al. (2019) ‘Neural Ordinary Differential Equations’. arXiv. Available at: http://arxiv.org/abs/1806.07366 (Accessed: 15 November 2022).

1 常微分方程及其数值解

1.1 常微分方程问题

常微分方程是只包含单个自变量 $x$、未知函数 $h(x)$ 和未知函数导数 $h’(x)$ 的等式,例如 $h’(x) = 2x$ 就是一个常微分方程。但更常见的可以表示为 $dh(x)/dx = f(h(x), x)$,其中 $f(h(x), x)$ 表示由 $h(x)$ 和 $x$ 组成的表达式。对于常微分方程,一般的理解是希望得出 $h(x)$ 的解析表达式,例如 $h’(x) = 2x$ 的通解为 $h(x)=x^2 +C$,其中 $C$ 表示任意常数。但在工程中更常用的是数值解,即给定一个初值 $h(x_0)$,我们希望解出末值 $h(x_1)$,这样并不需要得出 $h(x)$ 的解析形式,只需要一步步逼近它就行了。

我们可以对初值为 $h(0)$ 的常微分方程作出如下形式化描述:

$$
\frac{dh(t)}{dt} = f(t, h(t))
$$

对 ODE 进行积分有

$$
h(t) = h(s) + \int_s^t f(x, h(x)) dx
$$

根据积分中值定理,可近似为

$$
h(t) = h(s) + (t - s) \cdot f(x, h(x)), x \in [s, t]
$$

进而可以写成递推式:

$$
h(t + \Delta t) = h(t) + \Delta t \cdot f(x, h(x))
$$

如果能估计出各步的 $f(x, h(x))$,由 $h(0)$ 开始就可以逐步推出 $h(T)$

此外,如果步长为负则得到反向的递推公式

$$
h(t - \Delta t) = h(t) - \Delta t \cdot f(x, h(x))
$$

可由 $h(T)$ 逐步推出 $h(0)$。

1.2 传统求解器举例

现有各种 ODE Solver 对 $f(x, h(x))$ 进行了不同的估计:

  1. Euler 法:$f(x, h(x)) \approx f(s, h(s))$
  2. 中点法(RK2):$f(x, h(x)) \approx \frac{1}{2} \cdot (f(s, h(s)) + f(t, h(s) + (t - s) \cdot f(s, h(s))))$

此外还有 Runge-Kutta (RK4) 和 Adams-Bashforth 等(参见计算数学相关文献)。

以 Euler 法为例,当步长 $\Delta t$ 较小时,可以用

$$
h(t + \Delta t) \approx h(t) + \Delta t \cdot f(t, h(t))
$$

从 $h(t)$ 的初值开始一步步估算出终值。

伪码如下(步长不必固定):

1
2
3
for dt in dts:
h += dt * f(t, h)
t += dt

2 问题的提出

现在回过头来讨论神经网络,本质上不论是全连接、循环还是卷积网络,它们都类似于一个非常复杂的复合函数,复合次数就等于层级的深度。例如两层全连接网络可以表示为 $Y=g(g(X, θ_1), θ_2)$,因此每一个神经网络层级都类似于万能函数逼近器。因为整体是复合函数,所以很容易接受复合函数的求导方法:链式法则,并将梯度从最外一层的函数一点点先向里面层级的函数传递,并且每传到一层函数,就可以更新该层的参数 $θ$。现在问题是,在前向传播过后需要保留所有层的激活值,并在沿计算路径反传梯度时利用这些激活值。这对内存的占用非常大,因此也就限制了深度模型的训练过程。

注意到残差网络、RNN 网络的 decoder、归一化流的模型与上述 Euler 法有类似的形式:

$$
h_{t+1} = h_{t} + f(h_t \thet\mathbf{a}_t)
$$

ODE Solver cluster_ode t=0 h0 h(0) dot . h0->dot ht h(T) f f * dt dot->f h, t plus + dot->plus f->plus t1 ... plus->t1 tt t=T t1->tt tt->ht
ODE Solver
两层 ResNet 的前向计算 cluster_l1 ResNet Layer 1 x x dot1 . x->dot1 y y conv1 conv dot1->conv1 plus1 + dot1->plus1 conv1->plus1 relu1 σ plus1->relu1 l2 ResNet Layer 2 relu1->l2 l2->y
两层 ResNet 的前向计算

如果假设这些离散的神经网络层($\Delta t=1$)之间还存在无穷多层,使得步长 $\Delta t$ 无穷小,以至于 $t$ 连续,则这些模型可写成 ODE 的形式:

$$
\frac{dh(t)}{dt} = f(h(t), t, \theta)
$$

其输入层是这个 ODE 的初值 $h(0)$,输出层就是解 $h(t)$ 在 $t=T$ 时的值 $h(T)$,模型本身可以被视为对初值为 $h(0)$ 的常微分方程定义的某个连续函数 $h(t)$ 的离散化。或者反过来理解,当我们添加更多层并采取更小的步骤时,可以使用神经网络定义的常微分方程 (ODE) 来参数化隐单元的连续动态:

ResNet and Neural ODE

将神经网络建模为 ODE 的好处:

  • 内存效率: 可以不对求解器的运算进行反向传播计算梯度,进而不用存储前向传递的任何中间量,解决了神经网络训练内存限制的瓶颈。
  • 计算的自适应: 高效准确的 ODE 求解器已经发展了 120 多年,并且现代 ODE 求解器会提供关于误差增长的保证、监测误差水平、动态调整计算策略以达到所要求的精度水平,这使得网络深度是自适应的,通过数值误差的阈值进行控制,可在精度与速度之间权衡。
  • 可扩展和可逆的归一化流: 连续转换的好处是变量变化公式更容易计算了,可以用它来构建新的可逆密度模型,进而避免了归一化流的瓶颈,并且可以直接通过最大似然进行训练。
  • 连续时间序列: 与需要离散化观测和发射间隔的递归神经网络不同,连续定义的动态可以自然地获得中间任意时刻的隐状态。

2 偏微分方程的反向自动微分

训练连续深度网络的主要技术难点是通过 ODE 求解器执行反向模式微分(也称为反向传播)。人们可以利用前向传播的运算进行微分,但会导致高内存成本并引入额外数值误差。 为此,我们将 ODE 求解器视为黑盒,并使用 伴随灵敏度方法 计算梯度(Pontryagin 等人,1962 年)。该方法通过 及时地 反向求解第二个 增广常微分方程 来计算梯度,适用于所有 ODE 求解器。这种方法与问题大小成线性关系,内存成本低,并明确控制数值误差。

考虑优化一个标量值的损失函数 $L(\cdot)$,其输入是某个 ODE 求解器的结果:

$$
L(\mathbf{z}(t_1)) = L \left( \mathbf{z}(t_0) + \int^{t_1}_{t_0} f(\mathbf{z}(t), t, θ)dt \right) = L( \text{ODESolve}(\mathbf{z}(t_0), f, t_0, t_1, θ))
$$

为了优化 $L$,我们需要关于 $θ$ 的梯度。第一步是确定损失对每个时刻的隐藏状态 $\mathbf{z}(t)$ 的梯度。这个量被称为伴随 $\mathbf{a}(t) = \partial L/\partial \mathbf{z}(t)$,其动态可以由另一个可以被视为链式法则瞬态模拟的 ODE 给出:

$$
\frac{d\mathbf{a}_{\mathbf{z}}(t)}{dt} = - \mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t),t,\theta)}{\partial \mathbf{z}}
$$

我们可以通过再次调用 ODE 求解器来计算 $\partial L/\partial \mathbf{z}(t_0)$。该求解器必须从 $\partial L/\partial \mathbf{z}(t_1)$ 的初始值开始反向运行。一个复杂问题是:求解此 ODE 需要知道 $\mathbf{z}(t)$ 沿其整个轨迹的值。不过,我们可以简单地从其终值 $\mathbf{z}(t_1)$ 开始重新计算 $\mathbf{z}(t)$ 及其伴随值。

计算关于参数 $θ$ 的梯度需要评估第三个积分,它取决于 $\mathbf{z}(t)$ 和 $\mathbf{a}(t)$:

$$
\frac{dL }{dθ} = − \int^{t_0}_{t1} \mathbf{a}(t)^T \frac{\partial f \left(\mathbf{z}(t), t, θ \right) }{\partial θ} dt
$$

在式 (4) 和式 (5) 中的 “向量-雅可比积” $\mathbf{a}(t)^T \frac{\partial f }{\partial z}$ 和 $\mathbf{a}(t)^T \frac{\partial f }{\partial θ}$ 可以通过自动微分有效地计算,其时间成本类似于计算 $f$。用于求解 $z$、$a$ 和 $\frac{\partial L }{\partial θ}$ 的所有积分都可以在对 ODE 求解器的一次调用中计算,它将原始状态、伴随和其他偏导数连接成一个向量。算法 1 展示了如何构建必要的动态,并调用 ODE 求解器一次计算所有梯度。

Algorithm01

大多数 ODE 求解器都可以选择多次输出状态 $\mathbf{z}(t)$。当损失取决于这些中间状态时,反向模式导数必须分解为一系列单独的求解,在每对连续的输出时间之间存在一个(图 2)。在每次观测时,必须在相应的偏导数 $\frac{\partial L}{\partial \mathbf{z}(t_i)}$ 的方向上调整伴随。

上述结果扩展了 Stapor 等(2018 年,第 2.4.2 节)的结果。算法 1 的扩展版本,包括相对于 $t_0$ 和 $t_1$ 的导数可以在 附录 C 中找到。详细的推导在 附录 B 中提供。附录 D 提供了 Python 代码,它通过扩展 autograd 自动微分包来计算 scipy.integrate.odeint 的所有导数。此代码还支持所有高阶导数。我们已经在 http://github.com/rtqichen/torchdiffeq 上发布了 PyTorch 实现(Paszke 等,2017 年),包括几个标准 ODE 求解器的基于 GPU 的实现。

title

$$
\begin{align*}
\frac{d \mathbf{a}_{\mathbf{z}}(t)}{dt}
&= - (\mathbf{a}_h(t), \mathbf{a}_\theta(t), \mathbf{a}_t(t)) \cdot \begin{pmatrix}
\frac{\partial {h’(t)}}{\partial {h(t)}} & \frac{\partial{h’(t)}}{\partial{\theta}} & \frac{\partial{h’(t)}}{\partial{t}} \\
0 & 0 & 0 \\
0 & 0 & 0
\end{pmatrix} \\
&= -( \mathbf{a}_h(t) \cdot \frac{\partial{h’(t)}}{\partial{h(t)}}, \mathbf{a}_h(t) \cdot \frac{\partial{h’(t)}}{\partial{\theta}}, \mathbf{a}_h(t) \cdot \frac{\partial{h’(t)}}{\partial{t}} )
\end{align*}
$$

于是得到了三个 ODE(原文构造了一个增广矩阵按一个 ODE 一起算,这里便于说明拆成了三个 ODE)
$$\begin{cases}
\frac{d\mathbf{a}_h(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{h(t)}} \\
\frac{d\mathbf{a}_\thet\mathbf{a}(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{\theta}} \\
\frac{d\mathbf{a}_t(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{t}}
\end{cases}$$

对应的最终时刻的值

  1. $\mathbf{a}_h(T) = \frac{\partial{L}}{\partial{h(T)}} $ 可由损失函数的定义计算出
  2. 作者说“setting $\mathbf{a}_\thet\mathbf{a}(T) = 0$”没说为什么,我猜是因为优化问题最终就是要让损失函数相对参数的梯度为0
  3. $\mathbf{a}_t(T) = \frac{\partial{L}}{\partial{t}} |_T = \frac{\partial{L}}{\partial{h}} \cdot \frac{\partial{h}}{\partial{t}} |_T = \mathbf{a}_t(T) \cdot f(h(T), T, \theta)$

于是反向递推可得到损失函数相对于初始时刻的状态、模型参数和时间 t 的梯度 $\mathbf{a}_h(0), \mathbf{a}_\thet\mathbf{a}(0), \mathbf{a}_t(0)$。

以优化$\theta$为例,记$\alpha = \frac{\partial{L}}{\partial{h(t_N)}}$,$\beta$为学习率,因为
$$\frac{\partial{L}}{\partial{\thet\mathbf{a}(t_0)}} = \frac{\partial{L}}{\partial{\thet\mathbf{a}(t_N)}} - \int_{t_N}^{t_0} \alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}} dt$$
其中第一项按前面设置是0,$\alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}}$可以在使用f(*args).backward(alpha)后通过theta.grad得到,而且很多时候神经网络里的参数都是可以让输出值随便线性调整的(最后一层不是非线性),所以时间间隔具体是多少也不重要,省略了也可以。

3 在监督学习任务中用 ODE 取代残差网络

Neural ODE on MNIST

  • 用更少参数达到了 ResNet 相似的准确度
  • 达到了常数级的空间复杂度
    • RK-Net 是用 Runge-Kutta 前向估算 ODE,但优化时仍用的 BP 算法,所以内存还是高
  • 时间复杂度都跟网络深度一致,但 ODE-Net 的深度是自适应的,无法直接比较(但后面实验表明 ODE-Net 更快)

Error Control in ODE-Nets

  1. a 表明 ODE 的数值解的误差阈值与计算量成反比,而计算量与运行时间成正比(b),所以可以训练时减小误差阈值以提高精度,测试时放松阈值以提升速度;
  2. 反向计算梯度时计算量比前向计算还少一半,作者认为是反向传播时评估节点可能更少

4 连续型归一化流

5 生成时间序列模型

ODE Time-Series Model

参考资料