【摘 要】 我们介绍了一个新的深度神经网络模型家族。在该模型中,我们并没有定义隐藏层的离散序列,而是使用神经网络对隐状态的导数进行了参数化,并使用黑盒微分方程求解器计算神经网络的输出。这些“连续深度” 的模型具有恒定的内存成本,这使其计算策略适应每个输入,并且可以明确地以数值精度换取速度。我们在“连续深度” 的残差网络和“连续时间”的隐变量模型中展示了这些性质。我们还构建了连续的归一化流,这是一种可以通过最大似然进行训练、且无需对数据维度进行分区或排序的生成式模型。对于训练,我们展示了在不访问内部计算的情况下,任意常微分方程求解的反向传播方法,这使大型模型能够对常微分方程进行端到端训练。

【原 文】 Chen, R.T.Q. et al. (2019) ‘Neural Ordinary Differential Equations’. arXiv. Available at: http://arxiv.org/abs/1806.07366 (Accessed: 15 November 2022).

1 常微分方程及其数值解

1.1 常微分方程问题

常微分方程是只包含单个自变量 xx、未知函数 h(x)h(x) 和未知函数导数 h(x)h'(x) 的等式,例如 h(x)=2xh'(x) = 2x 就是一个常微分方程。但更常见的可以表示为 dh(x)/dx=f(h(x),x)dh(x)/dx = f(h(x), x),其中 f(h(x),x)f(h(x), x) 表示由 h(x)h(x)xx 组成的表达式。对于常微分方程,一般的理解是希望得出 h(x)h(x) 的解析表达式,例如 h(x)=2xh'(x) = 2x 的通解为 h(x)=x2+Ch(x)=x^2 +C,其中 CC 表示任意常数。但在工程中更常用的是数值解,即给定一个初值 h(x0)h(x_0),我们希望解出末值 h(x1)h(x_1),这样并不需要得出 h(x)h(x) 的解析形式,只需要一步步逼近它就行了。

我们可以对初值为 h(0)h(0) 的常微分方程作出如下形式化描述:

dh(t)dt=f(t,h(t))\frac{dh(t)}{dt} = f(t, h(t))

对 ODE 进行积分有

h(t)=h(s)+stf(x,h(x))dxh(t) = h(s) + \int_s^t f(x, h(x)) dx

根据积分中值定理,可近似为

h(t)=h(s)+(ts)f(x,h(x)),x[s,t]h(t) = h(s) + (t - s) \cdot f(x, h(x)), x \in [s, t]

进而可以写成递推式:

h(t+Δt)=h(t)+Δtf(x,h(x))h(t + \Delta t) = h(t) + \Delta t \cdot f(x, h(x))

如果能估计出各步的 f(x,h(x))f(x, h(x)),由 h(0)h(0) 开始就可以逐步推出 h(T)h(T)

此外,如果步长为负则得到反向的递推公式

h(tΔt)=h(t)Δtf(x,h(x))h(t - \Delta t) = h(t) - \Delta t \cdot f(x, h(x))

可由 h(T)h(T) 逐步推出 h(0)h(0)

1.2 传统求解器举例

现有各种 ODE Solver 对 f(x,h(x))f(x, h(x)) 进行了不同的估计:

  1. Euler 法:f(x,h(x))f(s,h(s))f(x, h(x)) \approx f(s, h(s))
  2. 中点法(RK2):f(x,h(x))12(f(s,h(s))+f(t,h(s)+(ts)f(s,h(s))))f(x, h(x)) \approx \frac{1}{2} \cdot (f(s, h(s)) + f(t, h(s) + (t - s) \cdot f(s, h(s))))

此外还有 Runge-Kutta (RK4) 和 Adams-Bashforth 等(参见计算数学相关文献)。

以 Euler 法为例,当步长 Δt\Delta t 较小时,可以用

h(t+Δt)h(t)+Δtf(t,h(t))h(t + \Delta t) \approx h(t) + \Delta t \cdot f(t, h(t))

h(t)h(t) 的初值开始一步步估算出终值。

伪码如下(步长不必固定):

1
2
3
for dt in dts:
h += dt * f(t, h)
t += dt

2 问题的提出

现在回过头来讨论神经网络,本质上不论是全连接、循环还是卷积网络,它们都类似于一个非常复杂的复合函数,复合次数就等于层级的深度。例如两层全连接网络可以表示为 Y=g(g(X,θ1),θ2)Y=g(g(X, θ_1), θ_2),因此每一个神经网络层级都类似于万能函数逼近器。因为整体是复合函数,所以很容易接受复合函数的求导方法:链式法则,并将梯度从最外一层的函数一点点先向里面层级的函数传递,并且每传到一层函数,就可以更新该层的参数 θθ。现在问题是,在前向传播过后需要保留所有层的激活值,并在沿计算路径反传梯度时利用这些激活值。这对内存的占用非常大,因此也就限制了深度模型的训练过程。

注意到残差网络、RNN 网络的 decoder、归一化流的模型与上述 Euler 法有类似的形式:

h_{t+1} = h_{t} + f(h_t \thet\mathbf{a}_t)

ODE Solver cluster_ode t=0 h0 h(0) dot . h0->dot ht h(T) f f * dt dot->f h, t plus + dot->plus f->plus t1 ... plus->t1 tt t=T t1->tt tt->ht
ODE Solver
两层 ResNet 的前向计算 cluster_l1 ResNet Layer 1 x x dot1 . x->dot1 y y conv1 conv dot1->conv1 plus1 + dot1->plus1 conv1->plus1 relu1 σ plus1->relu1 l2 ResNet Layer 2 relu1->l2 l2->y
两层 ResNet 的前向计算

如果假设这些离散的神经网络层(Δt=1\Delta t=1)之间还存在无穷多层,使得步长 Δt\Delta t 无穷小,以至于 tt 连续,则这些模型可写成 ODE 的形式:

dh(t)dt=f(h(t),t,θ)\frac{dh(t)}{dt} = f(h(t), t, \theta)

其输入层是这个 ODE 的初值 h(0)h(0),输出层就是解 h(t)h(t)t=Tt=T 时的值 h(T)h(T),模型本身可以被视为对初值为 h(0)h(0) 的常微分方程定义的某个连续函数 h(t)h(t) 的离散化。或者反过来理解,当我们添加更多层并采取更小的步骤时,可以使用神经网络定义的常微分方程 (ODE) 来参数化隐单元的连续动态:

ResNet and Neural ODE

将神经网络建模为 ODE 的好处:

  • 内存效率: 可以不对求解器的运算进行反向传播计算梯度,进而不用存储前向传递的任何中间量,解决了神经网络训练内存限制的瓶颈。
  • 计算的自适应: 高效准确的 ODE 求解器已经发展了 120 多年,并且现代 ODE 求解器会提供关于误差增长的保证、监测误差水平、动态调整计算策略以达到所要求的精度水平,这使得网络深度是自适应的,通过数值误差的阈值进行控制,可在精度与速度之间权衡。
  • 可扩展和可逆的归一化流: 连续转换的好处是变量变化公式更容易计算了,可以用它来构建新的可逆密度模型,进而避免了归一化流的瓶颈,并且可以直接通过最大似然进行训练。
  • 连续时间序列: 与需要离散化观测和发射间隔的递归神经网络不同,连续定义的动态可以自然地获得中间任意时刻的隐状态。

2 偏微分方程的反向自动微分

训练连续深度网络的主要技术难点是通过 ODE 求解器执行反向模式微分(也称为反向传播)。人们可以利用前向传播的运算进行微分,但会导致高内存成本并引入额外数值误差。 为此,我们将 ODE 求解器视为黑盒,并使用 伴随灵敏度方法 计算梯度(Pontryagin 等人,1962 年)。该方法通过 及时地 反向求解第二个 增广常微分方程 来计算梯度,适用于所有 ODE 求解器。这种方法与问题大小成线性关系,内存成本低,并明确控制数值误差。

考虑优化一个标量值的损失函数 L()L(\cdot),其输入是某个 ODE 求解器的结果:

L(z(t1))=L(z(t0)+t0t1f(z(t),t,θ)dt)=L(ODESolve(z(t0),f,t0,t1,θ))L(\mathbf{z}(t_1)) = L \left( \mathbf{z}(t_0) + \int^{t_1}_{t_0} f(\mathbf{z}(t), t, θ)dt \right) = L( \text{ODESolve}(\mathbf{z}(t_0), f, t_0, t_1, θ))

为了优化 LL,我们需要关于 θθ 的梯度。第一步是确定损失对每个时刻的隐藏状态 z(t)\mathbf{z}(t) 的梯度。这个量被称为伴随 a(t)=L/z(t)\mathbf{a}(t) = \partial L/\partial \mathbf{z}(t),其动态可以由另一个可以被视为链式法则瞬态模拟的 ODE 给出:

daz(t)dt=a(t)Tf(z(t),t,θ)z\frac{d\mathbf{a}_{\mathbf{z}}(t)}{dt} = - \mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t),t,\theta)}{\partial \mathbf{z}}

我们可以通过再次调用 ODE 求解器来计算 L/z(t0)\partial L/\partial \mathbf{z}(t_0)。该求解器必须从 L/z(t1)\partial L/\partial \mathbf{z}(t_1) 的初始值开始反向运行。一个复杂问题是:求解此 ODE 需要知道 z(t)\mathbf{z}(t) 沿其整个轨迹的值。不过,我们可以简单地从其终值 z(t1)\mathbf{z}(t_1) 开始重新计算 z(t)\mathbf{z}(t) 及其伴随值。

计算关于参数 θθ 的梯度需要评估第三个积分,它取决于 z(t)\mathbf{z}(t)a(t)\mathbf{a}(t)

dLdθ=t1t0a(t)Tf(z(t),t,θ)θdt\frac{dL }{dθ} = − \int^{t_0}_{t1} \mathbf{a}(t)^T \frac{\partial f \left(\mathbf{z}(t), t, θ \right) }{\partial θ} dt

在式 (4) 和式 (5) 中的 “向量-雅可比积” a(t)Tfz\mathbf{a}(t)^T \frac{\partial f }{\partial z}a(t)Tfθ\mathbf{a}(t)^T \frac{\partial f }{\partial θ} 可以通过自动微分有效地计算,其时间成本类似于计算 ff。用于求解 zzaaLθ\frac{\partial L }{\partial θ} 的所有积分都可以在对 ODE 求解器的一次调用中计算,它将原始状态、伴随和其他偏导数连接成一个向量。算法 1 展示了如何构建必要的动态,并调用 ODE 求解器一次计算所有梯度。

Algorithm01

大多数 ODE 求解器都可以选择多次输出状态 z(t)\mathbf{z}(t)。当损失取决于这些中间状态时,反向模式导数必须分解为一系列单独的求解,在每对连续的输出时间之间存在一个(图 2)。在每次观测时,必须在相应的偏导数 Lz(ti)\frac{\partial L}{\partial \mathbf{z}(t_i)} 的方向上调整伴随。

上述结果扩展了 Stapor 等(2018 年,第 2.4.2 节)的结果。算法 1 的扩展版本,包括相对于 t0t_0t1t_1 的导数可以在 附录 C 中找到。详细的推导在 附录 B 中提供。附录 D 提供了 Python 代码,它通过扩展 autograd 自动微分包来计算 scipy.integrate.odeint 的所有导数。此代码还支持所有高阶导数。我们已经在 http://github.com/rtqichen/torchdiffeq 上发布了 PyTorch 实现(Paszke 等,2017 年),包括几个标准 ODE 求解器的基于 GPU 的实现。

title

daz(t)dt=(ah(t),aθ(t),at(t))(h(t)h(t)h(t)θh(t)t000000)=(ah(t)h(t)h(t),ah(t)h(t)θ,ah(t)h(t)t)\begin{align*} \frac{d \mathbf{a}_{\mathbf{z}}(t)}{dt} &= - (\mathbf{a}_h(t), \mathbf{a}_\theta(t), \mathbf{a}_t(t)) \cdot \begin{pmatrix} \frac{\partial {h'(t)}}{\partial {h(t)}} & \frac{\partial{h'(t)}}{\partial{\theta}} & \frac{\partial{h'(t)}}{\partial{t}} \\\\ 0 & 0 & 0 \\\\ 0 & 0 & 0 \end{pmatrix} \\\\ &= -( \mathbf{a}_h(t) \cdot \frac{\partial{h'(t)}}{\partial{h(t)}}, \mathbf{a}_h(t) \cdot \frac{\partial{h'(t)}}{\partial{\theta}}, \mathbf{a}_h(t) \cdot \frac{\partial{h'(t)}}{\partial{t}} ) \end{align*}

于是得到了三个 ODE(原文构造了一个增广矩阵按一个 ODE 一起算,这里便于说明拆成了三个 ODE)

\begin{cases} \frac{d\mathbf{a}_h(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{h(t)}} \\\\ \frac{d\mathbf{a}_\thet\mathbf{a}(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{\theta}} \\\\ \frac{d\mathbf{a}_t(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{t}} \end{cases}

对应的最终时刻的值

  1. $\mathbf{a}_h(T) = \frac{\partial{L}}{\partial{h(T)}} $ 可由损失函数的定义计算出
  2. 作者说“setting \mathbf{a}_\thet\mathbf{a}(T) = 0”没说为什么,我猜是因为优化问题最终就是要让损失函数相对参数的梯度为0
  3. at(T)=LtT=LhhtT=at(T)f(h(T),T,θ)\mathbf{a}_t(T) = \frac{\partial{L}}{\partial{t}} |_T = \frac{\partial{L}}{\partial{h}} \cdot \frac{\partial{h}}{\partial{t}} |_T = \mathbf{a}_t(T) \cdot f(h(T), T, \theta)

于是反向递推可得到损失函数相对于初始时刻的状态、模型参数和时间 t 的梯度 \mathbf{a}_h(0), \mathbf{a}_\thet\mathbf{a}(0), \mathbf{a}_t(0)

以优化θ\theta为例,记α=Lh(tN)\alpha = \frac{\partial{L}}{\partial{h(t_N)}}β\beta为学习率,因为

\frac{\partial{L}}{\partial{\thet\mathbf{a}(t_0)}} = \frac{\partial{L}}{\partial{\thet\mathbf{a}(t_N)}} - \int_{t_N}^{t_0} \alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}} dt

其中第一项按前面设置是0,\alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}}可以在使用f(*args).backward(alpha)后通过theta.grad得到,而且很多时候神经网络里的参数都是可以让输出值随便线性调整的(最后一层不是非线性),所以时间间隔具体是多少也不重要,省略了也可以。

3 在监督学习任务中用 ODE 取代残差网络

Neural ODE on MNIST

  • 用更少参数达到了 ResNet 相似的准确度
  • 达到了常数级的空间复杂度
    • RK-Net 是用 Runge-Kutta 前向估算 ODE,但优化时仍用的 BP 算法,所以内存还是高
  • 时间复杂度都跟网络深度一致,但 ODE-Net 的深度是自适应的,无法直接比较(但后面实验表明 ODE-Net 更快)

Error Control in ODE-Nets

  1. a 表明 ODE 的数值解的误差阈值与计算量成反比,而计算量与运行时间成正比(b),所以可以训练时减小误差阈值以提高精度,测试时放松阈值以提升速度;
  2. 反向计算梯度时计算量比前向计算还少一半,作者认为是反向传播时评估节点可能更少

4 连续型归一化流

5 生成时间序列模型

ODE Time-Series Model

参考资料