神经常微分方程
【摘 要】 我们介绍了一个新的深度神经网络模型家族。在该模型中,我们并没有定义隐藏层的离散序列,而是使用神经网络对隐状态的导数进行了参数化,并使用黑盒微分方程求解器计算神经网络的输出。这些“连续深度” 的模型具有恒定的内存成本,这使其计算策略适应每个输入,并且可以明确地以数值精度换取速度。我们在“连续深度” 的残差网络和“连续时间”的隐变量模型中展示了这些性质。我们还构建了连续的归一化流,这是一种可以通过最大似然进行训练、且无需对数据维度进行分区或排序的生成式模型。对于训练,我们展示了在不访问内部计算的情况下,任意常微分方程求解的反向传播方法,这使大型模型能够对常微分方程进行端到端训练。
【原 文】 Chen, R.T.Q. et al. (2019) ‘Neural Ordinary Differential Equations’. arXiv. Available at: http://arxiv.org/abs/1806.07366 (Accessed: 15 November 2022).
1 常微分方程及其数值解
1.1 常微分方程问题
常微分方程是只包含单个自变量 、未知函数 和未知函数导数 的等式,例如 就是一个常微分方程。但更常见的可以表示为 ,其中 表示由 和 组成的表达式。对于常微分方程,一般的理解是希望得出 的解析表达式,例如 的通解为 ,其中 表示任意常数。但在工程中更常用的是数值解,即给定一个初值 ,我们希望解出末值 ,这样并不需要得出 的解析形式,只需要一步步逼近它就行了。
我们可以对初值为 的常微分方程作出如下形式化描述:
对 ODE 进行积分有
根据积分中值定理,可近似为
进而可以写成递推式:
如果能估计出各步的 ,由 开始就可以逐步推出
此外,如果步长为负则得到反向的递推公式
可由 逐步推出 。
1.2 传统求解器举例
现有各种 ODE Solver 对 进行了不同的估计:
- Euler 法:
- 中点法(RK2):
此外还有 Runge-Kutta (RK4) 和 Adams-Bashforth 等(参见计算数学相关文献)。
以 Euler 法为例,当步长 较小时,可以用
从 的初值开始一步步估算出终值。
伪码如下(步长不必固定):
1 | for dt in dts: |
2 问题的提出
现在回过头来讨论神经网络,本质上不论是全连接、循环还是卷积网络,它们都类似于一个非常复杂的复合函数,复合次数就等于层级的深度。例如两层全连接网络可以表示为 ,因此每一个神经网络层级都类似于万能函数逼近器。因为整体是复合函数,所以很容易接受复合函数的求导方法:链式法则,并将梯度从最外一层的函数一点点先向里面层级的函数传递,并且每传到一层函数,就可以更新该层的参数 。现在问题是,在前向传播过后需要保留所有层的激活值,并在沿计算路径反传梯度时利用这些激活值。这对内存的占用非常大,因此也就限制了深度模型的训练过程。
注意到残差网络、RNN 网络的 decoder、归一化流的模型与上述 Euler 法有类似的形式:
h_{t+1} = h_{t} + f(h_t \thet\mathbf{a}_t)
如果假设这些离散的神经网络层()之间还存在无穷多层,使得步长 无穷小,以至于 连续,则这些模型可写成 ODE 的形式:
其输入层是这个 ODE 的初值 ,输出层就是解 在 时的值 ,模型本身可以被视为对初值为 的常微分方程定义的某个连续函数 的离散化。或者反过来理解,当我们添加更多层并采取更小的步骤时,可以使用神经网络定义的常微分方程 (ODE) 来参数化隐单元的连续动态:
将神经网络建模为 ODE 的好处:
- 内存效率: 可以不对求解器的运算进行反向传播计算梯度,进而不用存储前向传递的任何中间量,解决了神经网络训练内存限制的瓶颈。
- 计算的自适应: 高效准确的 ODE 求解器已经发展了 120 多年,并且现代 ODE 求解器会提供关于误差增长的保证、监测误差水平、动态调整计算策略以达到所要求的精度水平,这使得网络深度是自适应的,通过数值误差的阈值进行控制,可在精度与速度之间权衡。
- 可扩展和可逆的归一化流: 连续转换的好处是变量变化公式更容易计算了,可以用它来构建新的可逆密度模型,进而避免了归一化流的瓶颈,并且可以直接通过最大似然进行训练。
- 连续时间序列: 与需要离散化观测和发射间隔的递归神经网络不同,连续定义的动态可以自然地获得中间任意时刻的隐状态。
2 偏微分方程的反向自动微分
训练连续深度网络的主要技术难点是通过 ODE 求解器执行反向模式微分(也称为反向传播)。人们可以利用前向传播的运算进行微分,但会导致高内存成本并引入额外数值误差。 为此,我们将 ODE 求解器视为黑盒,并使用 伴随灵敏度方法 计算梯度(Pontryagin 等人,1962 年)。该方法通过 及时地 反向求解第二个 增广常微分方程 来计算梯度,适用于所有 ODE 求解器。这种方法与问题大小成线性关系,内存成本低,并明确控制数值误差。
考虑优化一个标量值的损失函数 ,其输入是某个 ODE 求解器的结果:
为了优化 ,我们需要关于 的梯度。第一步是确定损失对每个时刻的隐藏状态 的梯度。这个量被称为伴随 ,其动态可以由另一个可以被视为链式法则瞬态模拟的 ODE 给出:
我们可以通过再次调用 ODE 求解器来计算 。该求解器必须从 的初始值开始反向运行。一个复杂问题是:求解此 ODE 需要知道 沿其整个轨迹的值。不过,我们可以简单地从其终值 开始重新计算 及其伴随值。
计算关于参数 的梯度需要评估第三个积分,它取决于 和 :
在式 (4) 和式 (5) 中的 “向量-雅可比积” 和 可以通过自动微分有效地计算,其时间成本类似于计算 。用于求解 、 和 的所有积分都可以在对 ODE 求解器的一次调用中计算,它将原始状态、伴随和其他偏导数连接成一个向量。算法 1
展示了如何构建必要的动态,并调用 ODE 求解器一次计算所有梯度。
大多数 ODE 求解器都可以选择多次输出状态 。当损失取决于这些中间状态时,反向模式导数必须分解为一系列单独的求解,在每对连续的输出时间之间存在一个(图 2
)。在每次观测时,必须在相应的偏导数 的方向上调整伴随。
上述结果扩展了 Stapor 等(2018 年,第 2.4.2 节)的结果。算法 1
的扩展版本,包括相对于 和 的导数可以在 附录 C
中找到。详细的推导在 附录 B
中提供。附录 D
提供了 Python 代码,它通过扩展 autograd
自动微分包来计算 scipy.integrate.odeint
的所有导数。此代码还支持所有高阶导数。我们已经在 http://github.com/rtqichen/torchdiffeq 上发布了 PyTorch 实现(Paszke 等,2017 年),包括几个标准 ODE 求解器的基于 GPU 的实现。
title
于是得到了三个 ODE(原文构造了一个增广矩阵按一个 ODE 一起算,这里便于说明拆成了三个 ODE)
\begin{cases} \frac{d\mathbf{a}_h(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{h(t)}} \\\\ \frac{d\mathbf{a}_\thet\mathbf{a}(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{\theta}} \\\\ \frac{d\mathbf{a}_t(t)}{dt} = - \mathbf{a}_h(t) \cdot \frac{\partial{f(h(t), t, \theta)}}{\partial{t}} \end{cases}
对应的最终时刻的值
- $\mathbf{a}_h(T) = \frac{\partial{L}}{\partial{h(T)}} $ 可由损失函数的定义计算出
- 作者说“setting \mathbf{a}_\thet\mathbf{a}(T) = 0”没说为什么,我猜是因为优化问题最终就是要让损失函数相对参数的梯度为0
于是反向递推可得到损失函数相对于初始时刻的状态、模型参数和时间 t 的梯度 \mathbf{a}_h(0), \mathbf{a}_\thet\mathbf{a}(0), \mathbf{a}_t(0)。
以优化为例,记,为学习率,因为
\frac{\partial{L}}{\partial{\thet\mathbf{a}(t_0)}} = \frac{\partial{L}}{\partial{\thet\mathbf{a}(t_N)}} - \int_{t_N}^{t_0} \alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}} dt
其中第一项按前面设置是0,\alpha \cdot \frac{\partial{f(\theta)}}{\partial{\thet\mathbf{a}(t)}}可以在使用f(*args).backward(alpha)
后通过theta.grad
得到,而且很多时候神经网络里的参数都是可以让输出值随便线性调整的(最后一层不是非线性),所以时间间隔具体是多少也不重要,省略了也可以。
3 在监督学习任务中用 ODE 取代残差网络
- 用更少参数达到了 ResNet 相似的准确度
- 达到了常数级的空间复杂度
- RK-Net 是用 Runge-Kutta 前向估算 ODE,但优化时仍用的 BP 算法,所以内存还是高
- 时间复杂度都跟网络深度一致,但 ODE-Net 的深度是自适应的,无法直接比较(但后面实验表明 ODE-Net 更快)
- a 表明 ODE 的数值解的误差阈值与计算量成反比,而计算量与运行时间成正比(b),所以可以训练时减小误差阈值以提高精度,测试时放松阈值以提升速度;
- 反向计算梯度时计算量比前向计算还少一半,作者认为是反向传播时评估节点可能更少