神经常微分方程
【摘 要】 我们介绍了一个新的深度神经网络模型家族。在该模型中,我们并没有定义隐藏层的离散序列,而是使用神经网络对隐状态的导数进行了参数化,并使用黑盒微分方程求解器计算神经网络的输出。这些“连续深度” 的模型具有恒定的内存成本,这使其计算策略适应每个输入,并且可以明确地以数值精度换取速度。我们在“连续深度” 的残差网络和“连续时间”的隐变量模型中展示了这些性质。我们还构建了连续的归一化流,这是一种可以通过最大似然进行训练、且无需对数据维度进行分区或排序的生成式模型。对于训练,我们展示了在不访问内部计算的情况下,任意常微分方程求解的反向传播方法,这使大型模型能够对常微分方程进行端到端训练。
【原 文】 Chen, R.T.Q. et al. (2019) ‘Neural Ordinary Differential Equations’. arXiv. Available at: http://arxiv.org/abs/1806.07366 (Accessed: 15 November 2022).
1 常微分方程及其数值解
1.1 常微分方程问题
常微分方程是只包含单个自变量 $x$、未知函数 $h(x)$ 和未知函数导数 $h’(x)$ 的等式,例如 $h’(x) = 2x$ 就是一个常微分方程。但更常见的可以表示为 $dh(x)/dx = f(h(x), x)$,其中 $f(h(x), x)$ 表示由 $h(x)$ 和 $x$ 组成的表达式。对于常微分方程,一般的理解是希望得出 $h(x)$ 的解析表达式,例如 $h’(x) = 2x$ 的通解为 $h(x)=x^2 +C$,其中 $C$ 表示任意常数。但在工程中更常用的是数值解,即给定一个初值 $h(x_0)$,我们希望解出末值 $h(x_1)$,这样并不需要得出 $h(x)$ 的解析形式,只需要一步步逼近它就行了。
我们可以对初值为 $h(0)$ 的常微分方程作出如下形式化描述:
$$
\frac{dh(t)}{dt} = f(t, h(t))
$$
对 ODE 进行积分有
$$
h(t) = h(s) + \int_s^t f(x, h(x)) dx
$$
根据积分中值定理,可近似为
$$
h(t) = h(s) + (t - s) \cdot f(x, h(x)), x \in [s, t]
$$
进而可以写成递推式:
$$
h(t + \Delta t) = h(t) + \Delta t \cdot f(x, h(x))
$$
如果能估计出各步的 $f(x, h(x))$,由 $h(0)$ 开始就可以逐步推出 $h(T)$
此外,如果步长为负则得到反向的递推公式
$$
h(t - \Delta t) = h(t) - \Delta t \cdot f(x, h(x))
$$
可由 $h(T)$ 逐步推出 $h(0)$。
1.2 传统求解器举例
现有各种 ODE Solver 对 $f(x, h(x))$ 进行了不同的估计:
- Euler 法:$f(x, h(x)) \approx f(s, h(s))$
- 中点法(RK2):$f(x, h(x)) \approx \frac{1}{2} \cdot (f(s, h(s)) + f(t, h(s) + (t - s) \cdot f(s, h(s))))$
此外还有 Runge-Kutta (RK4) 和 Adams-Bashforth 等(参见计算数学相关文献)。
以 Euler 法为例,当步长 $\Delta t$ 较小时,可以用
$$
h(t + \Delta t) \approx h(t) + \Delta t \cdot f(t, h(t))
$$
从 $h(t)$ 的初值开始一步步估算出终值。
伪码如下(步长不必固定):
1 | for dt in dts: |
2 问题的提出
现在回过头来讨论神经网络,本质上不论是全连接、循环还是卷积网络,它们都类似于一个非常复杂的复合函数,复合次数就等于层级的深度。例如两层全连接网络可以表示为 $Y=g(g(X, θ_1), θ_2)$,因此每一个神经网络层级都类似于万能函数逼近器。因为整体是复合函数,所以很容易接受复合函数的求导方法:链式法则,并将梯度从最外一层的函数一点点先向里面层级的函数传递,并且每传到一层函数,就可以更新该层的参数 $θ$。现在问题是,在前向传播过后需要保留所有层的激活值,并在沿计算路径反传梯度时利用这些激活值。这对内存的占用非常大,因此也就限制了深度模型的训练过程。
注意到残差网络、RNN 网络的 decoder、归一化流的模型与上述 Euler 法有类似的形式:
$$
h_{t+1} = h_{t} + f(h_t \thet\mathbf{a}_t)
$$
如果假设这些离散的神经网络层($\Delta t=1$)之间还存在无穷多层,使得步长 $\Delta t$ 无穷小,以至于 $t$ 连续,则这些模型可写成 ODE 的形式:
$$
\frac{dh(t)}{dt} = f(h(t), t, \theta)
$$
其输入层是这个 ODE 的初值 $h(0)$,输出层就是解 $h(t)$ 在 $t=T$ 时的值 $h(T)$,模型本身可以被视为对初值为 $h(0)$ 的常微分方程定义的某个连续函数 $h(t)$ 的离散化。或者反过来理解,当我们添加更多层并采取更小的步骤时,可以使用神经网络定义的常微分方程 (ODE) 来参数化隐单元的连续动态:
将神经网络建模为 ODE 的好处:
- 内存效率: 可以不对求解器的运算进行反向传播计算梯度,进而不用存储前向传递的任何中间量,解决了神经网络训练内存限制的瓶颈。
- 计算的自适应: 高效准确的 ODE 求解器已经发展了 120 多年,并且现代 ODE 求解器会提供关于误差增长的保证、监测误差水平、动态调整计算策略以达到所要求的精度水平,这使得网络深度是自适应的,通过数值误差的阈值进行控制,可在精度与速度之间权衡。
- 可扩展和可逆的归一化流: 连续转换的好处是变量变化公式更容易计算了,可以用它来构建新的可逆密度模型,进而避免了归一化流的瓶颈,并且可以直接通过最大似然进行训练。
- 连续时间序列: 与需要离散化观测和发射间隔的递归神经网络不同,连续定义的动态可以自然地获得中间任意时刻的隐状态。
2 偏微分方程的反向自动微分
训练连续深度网络的主要技术难点是通过 ODE 求解器执行反向模式微分(也称为反向传播)。人们可以利用前向传播的运算进行微分,但会导致高内存成本并引入额外数值误差。 为此,我们将 ODE 求解器视为黑盒,并使用 伴随灵敏度方法 计算梯度(Pontryagin 等人,1962 年)。该方法通过 及时地 反向求解第二个 增广常微分方程 来计算梯度,适用于所有 ODE 求解器。这种方法与问题大小成线性关系,内存成本低,并明确控制数值误差。
考虑优化一个标量值的损失函数 $L(\cdot)$,其输入是某个 ODE 求解器的结果:
$$
L(\mathbf{z}(t_1)) = L \left( \mathbf{z}(t_0) + \int^{t_1}_{t_0} f(\mathbf{z}(t), t, θ)dt \right) = L( \text{ODESolve}(\mathbf{z}(t_0), f, t_0, t_1, θ))
$$
为了优化 $L$,我们需要关于 $θ$ 的梯度。第一步是确定损失对每个时刻的隐藏状态 $\mathbf{z}(t)$ 的梯度。这个量被称为伴随 $\mathbf{a}(t) = \partial L/\partial \mathbf{z}(t)$,其动态可以由另一个可以被视为链式法则瞬态模拟的 ODE 给出:
$$
\frac{d\mathbf{a}_{\mathbf{z}}(t)}{dt} = - \mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t),t,\theta)}{\partial \mathbf{z}}
$$
我们可以通过再次调用 ODE 求解器来计算 $\partial L/\partial \mathbf{z}(t_0)$。该求解器必须从 $\partial L/\partial \mathbf{z}(t_1)$ 的初始值开始反向运行。一个复杂问题是:求解此 ODE 需要知道 $\mathbf{z}(t)$ 沿其整个轨迹的值。不过,我们可以简单地从其终值 $\mathbf{z}(t_1)$ 开始重新计算 $\mathbf{z}(t)$ 及其伴随值。
计算关于参数 $θ$ 的梯度需要评估第三个积分,它取决于 $\mathbf{z}(t)$ 和 $\mathbf{a}(t)$:
$$
\frac{dL }{dθ} = − \int^{t_0}_{t1} \mathbf{a}(t)^T \frac{\partial f \left(\mathbf{z}(t), t, θ \right) }{\partial θ} dt
$$
在式 (4) 和式 (5) 中的 “向量-雅可比积” $\mathbf{a}(t)^T \frac{\partial f }{\partial z}$ 和 $\mathbf{a}(t)^T \frac{\partial f }{\partial θ}$ 可以通过自动微分有效地计算,其时间成本类似于计算 $f$。用于求解 $z$、$a$ 和 $\frac{\partial L }{\partial θ}$ 的所有积分都可以在对 ODE 求解器的一次调用中计算,它将原始状态、伴随和其他偏导数连接成一个向量。算法 1
展示了如何构建必要的动态,并调用 ODE 求解器一次计算所有梯度。
大多数 ODE 求解器都可以选择多次输出状态 $\mathbf{z}(t)$。当损失取决于这些中间状态时,反向模式导数必须分解为一系列单独的求解,在每对连续的输出时间之间存在一个(图 2
)。在每次观测时,必须在相应的偏导数 $\frac{\partial L}{\partial \mathbf{z}(t_i)}$ 的方向上调整伴随。
上述结果扩展了 Stapor 等(2018 年,第 2.4.2 节)的结果。算法 1
的扩展版本,包括相对于 $t_0$ 和 $t_1$ 的导数可以在 附录 C
中找到。详细的推导在 附录 B
中提供。附录 D
提供了 Python 代码,它通过扩展 autograd
自动微分包来计算 scipy.integrate.odeint
的所有导数。此代码还支持所有高阶导数。我们已经在 http://github.com/rtqichen/torchdiffeq 上发布了 PyTorch 实现(Paszke 等,2017 年),包括几个标准 ODE 求解器的基于 GPU 的实现。
title
$$
\begin{align*}
\frac{d \mathbf{a}_{\mathbf{z}}(t)}{dt}
&= - (\mathbf{a}_h(t), \mathbf{a}_\theta(t), \mathbf{a}_t(t)) \cdot \begin{pmatrix}
\frac{\partial {h’(t)}}{\partial {h(t)}} & \frac{\partial{h’(t)}}{\partial{\theta}} & \frac{\partial{h’(t)}}{\partial{t}} \\
0 & 0 & 0 \\
0 & 0 & 0
\end{pmatrix} \\
&= -( \mathbf{a}_h(t) \cdot \frac{\partial{h’(t)}}{\partial{h(t)}}, \mathbf{a}_h(t) \cdot \frac{\partial{h’(t)}}{\partial{\theta}}, \mathbf{a}_h(t) \cdot \frac{\partial{h’(t)}}{\partial{t}} )
\end{align*}
$$
于是得到了三个 ODE(原文构造了一个增广矩阵按一个 ODE 一起算,这里便于说明拆成了三个 ODE)
$$\begin{cases}
\frac{d\mathbf{a}_h(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{h(t)}} \\
\frac{d\mathbf{a}_\thet\mathbf{a}(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{\theta}} \\
\frac{d\mathbf{a}_t(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{t}}
\end{cases}$$
对应的最终时刻的值
- $\mathbf{a}_h(T) = \frac{\partial{L}}{\partial{h(T)}} $ 可由损失函数的定义计算出
- 作者说“setting $\mathbf{a}_\thet\mathbf{a}(T) = 0$”没说为什么,我猜是因为优化问题最终就是要让损失函数相对参数的梯度为0
- $\mathbf{a}_t(T) = \frac{\partial{L}}{\partial{t}} |_T = \frac{\partial{L}}{\partial{h}} \cdot \frac{\partial{h}}{\partial{t}} |_T = \mathbf{a}_t(T) \cdot f(h(T), T, \theta)$
于是反向递推可得到损失函数相对于初始时刻的状态、模型参数和时间 t 的梯度 $\mathbf{a}_h(0), \mathbf{a}_\thet\mathbf{a}(0), \mathbf{a}_t(0)$。
以优化$\theta$为例,记$\alpha = \frac{\partial{L}}{\partial{h(t_N)}}$,$\beta$为学习率,因为
$$\frac{\partial{L}}{\partial{\thet\mathbf{a}(t_0)}} = \frac{\partial{L}}{\partial{\thet\mathbf{a}(t_N)}} - \int_{t_N}^{t_0} \alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}} dt$$
其中第一项按前面设置是0,$\alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}}$可以在使用f(*args).backward(alpha)
后通过theta.grad
得到,而且很多时候神经网络里的参数都是可以让输出值随便线性调整的(最后一层不是非线性),所以时间间隔具体是多少也不重要,省略了也可以。
3 在监督学习任务中用 ODE 取代残差网络
- 用更少参数达到了 ResNet 相似的准确度
- 达到了常数级的空间复杂度
- RK-Net 是用 Runge-Kutta 前向估算 ODE,但优化时仍用的 BP 算法,所以内存还是高
- 时间复杂度都跟网络深度一致,但 ODE-Net 的深度是自适应的,无法直接比较(但后面实验表明 ODE-Net 更快)
- a 表明 ODE 的数值解的误差阈值与计算量成反比,而计算量与运行时间成正比(b),所以可以训练时减小误差阈值以提高精度,测试时放松阈值以提升速度;
- 反向计算梯度时计算量比前向计算还少一半,作者认为是反向传播时评估节点可能更少