【摘 要】 NEURAL TANGENTS 是一个库,旨在实现对无限宽神经网络的研究。它提供了一个高级 API,用于指定复杂和分层的神经网络架构。然后可以像往常一样以有限宽或无限宽极限对这些网络进行训练和评估。无限宽网络可以使用精确的贝叶斯推断或通过神经正切核使用梯度下降进行分析训练。此外,NEURAL TANGENTS 提供了工具来研究函数空间或权重空间中广泛但有限网络的梯度下降训练动力学。

【原 文】 Novak, R. et al. (2019) ‘Neural Tangents: Fast and Easy Infinite Neural Networks in Python’. arXiv. Available at: http://arxiv.org/abs/1912.02803 (Accessed: 4 March 2023). https://github.com/google/neural-tangents

1 简介

深度神经网络 (DNN) 的成功部分归功于高级、灵活和高效的软件库的广泛可用性,例如 Tensorflow(Abadi 等,2015)、Keras(Chollet 等,2015)、PyTorch.nn (Paszke 等,2017)、Chainer(Tokui 等,2015;Akiba 等,2017)、JAX(Bradbury 等,2018a)等。这些库使研究人员能够通过使用较小的基元构建复杂模型来快速构建复杂模型。新机器学习方法的成功同样取决于开发复杂的软件工具来支持它们。

1.1 无限宽贝叶斯神经网络

最近,一类新的机器学习模型引起了极大的关注,即深度无限宽神经网络。在无限宽极限下,一大类贝叶斯神经网络模型变成了的高斯过程(具有特定的、依赖于体系结构的组合核);这些模型被称为 “神经网络高斯过程” (NNGP)。这种对应关系首先由 Neal (1994) 为浅层全连接网络建立,并在 (Lee 等, 2018; Matthews 等, 2018b) 中扩展到多层设置。从那时起,这种对应关系已扩展到广泛的非线性(Matthews 等,2018a;Novak 等,2019)和包含卷积层(Garriga-Alonso 等,2019;Novak 等, 2019)、残差连接(Garriga-Alonso 等,2019)和池化(Novak 等,2019)的架构。个别架构的结果随后被推广,并且表明高斯过程对应物适用一般性的网络类别,此类网络在 (Yang, 2019) 中被称为 张量编程。在平均场理论和初始化背景下,人们还对定义 NNGP 核的递归关系进行了广泛研究,参见 (Cho & Saul, 2009; Daniely 等, 2016; Poole 等, 2016; Schoenholz 等, 2016; Yang & Schoenholz, 2017; Xiao 等, 2018; Li & Nguyen, 2019; Pretorius 等, 2018; Hayou 等, 2018; Karakida 等, 2018; Blumenfeld 等, 2019; Hayou 等, 2019).

1.2 梯度下降训练的无限宽神经网络

除了启用贝叶斯神经网络的封闭形式描述外,无限宽极限最近还提供了对梯度下降训练神经网络的新成果。去年的几篇论文表明,使用梯度下降方法训练的随机初始化神经网络,可以用与 NNGP 相关的一个分布来刻画,并被描述为 神经正切核 (NTK) (Jacot 等, 2018;Lee 等,2019;Chizat 等,2019),这是隐含在一些早期论文中的核(Li & Liang,2018;Allen-Zhu 等,2018;Du 等,2018a;2018b)。除了这种 “函数空间” 观点之外,Lee 等(2019) 还提出了关于宽网络极限的对偶 “权重空间” 观点,表明梯度下降中的神经网络可以用关于其初始参数的一阶泰勒展开很好地描述。

1.3 使用无限宽网络的实际障碍

这些发现建立了无限宽神经网络,以作为理解深度学习中广泛现象的有用理论工具。此外,在没有可训练核的高斯过程图像分类基准上,这些模型实现了最先进的性能,并且在某些情况下匹配甚至超过了有限宽网络的性能,这进一步证明了此类模型的实用性(Garriga-Alonso 等,2019;Novak 等,2019;Arora 等, 2019a),特别是对于全连接和局部连接的模型族(Lee 等,2018;Novak 等,2019;Arora 等,2019b)。

不过,尽管它们具有实用性,但使用 NNGP 和 NTK-GP 是一项艰巨的任务,可能需要经验丰富的从业者数周甚至数月的工作。对应于神经网络的核必须在体系结构基础上手工导出。总体而言,此过程费力且容易出错,让人想起了高质量自动微分软件爆发之前的神经网络状态。

1.4 主要贡献

在本文中,我们介绍了一个名为 NEURAL TANGENTS 的新开源软件库,以 JAX(Bradbury 等,2018a)为目标对象,加速对无限宽极限神经网络的研究。 NEURAL TANGENTS 的主要特点包括:

  • 具有高级神经网络 API,可用于指定复杂、分层的模型。使用此 API 指定的神经网络可以对其无限宽的 NNGP 核和 NTK 进行分析评估(第 2.1 节,代码 1、2、3,§B.2)。

  • 具有通过神经网络的蒙特卡洛采样来近似无限宽核的可调用函数,通常这些核是无法解析构造的。这些方法与用于构建网络的神经网络软件库无关,因此非常通用(第 2.2 节,图 2,§B.5)。

  • 具有对无线宽神经网络进行解析推断的 API,可以计算贝叶斯后验或计算具有 MSE 损失的连续梯度下降结果。 API 还包括通过数值求解 ODE 来执行推断的各种工具,包括:连续梯度下降、有-或-没有动量、任意损失函数、有限或无限时间(第 2.1 节,图 1,§B.4)。

  • 具有可计算神经网络关于给定参数设置的任意阶泰勒级数逼近的函数,以探索无限宽极限的权重空间视角(§B.6,图 6)。

  • 由于使用了 XLA,我们的库在 CPU、GPU 或 TPU 上开箱即用。核计算可以自动分布在多个加速器上,具有近乎完美的可扩展性( 第 3.2 节,图 5,§B.3)。

我们从三个简短示例 (第 2 节) 开始,这些示例演示了使用 NEURAL TANGENTS 对无限宽神经网络执行计算的简便性、效率和多功能性。有了库的顶层视图,我们将深入到库的更多技术方面 (第 3 节)。

1.5 背景

在这里,我们简要描述一下 NNGP 和 NTK 。

(1)NNGP

神经网络通常被构造为仿射变换,后接逐点的非线性计算。令 zil(x)z^l_i(x) 表示神经网络第 ll 层线性变换之后的第 ii 个激活前。在初始化时,神经网络参数是随机分布的,因此根据中心极限定理:激活前 zil(x)z^l_i(x) 也服从均值为零的高斯分布,并且由其协方差矩阵 k(x,x)=e[zil(x)zil(x)]\mathcal{k}(x, x') = \mathbb{e}[ z^l_i(x) z^l_i(x')] 描述。这显然符合高斯过程的定义,也就是说无线宽极限下的全连接神经网络就是一个具有核 k(x,x)\mathcal{k}(x, x') 的高斯过程,我们称之为神经网络高斯过程( NNGP )。可以使用 NNGP 在测试点 xx 处进行贝叶斯后验预测,根据高斯过程原理,在测试点 xx 处的预测分布应当也是一个高斯分布,其均值为 μ(x)=K(x,X)K(X,X)1Yμ(x) = \mathcal{K}(x, \mathcal{X}) \mathcal{K}(\mathcal{X,X})^{−1} \mathcal{Y},方差为 σ2(x)=K(x,x)K(x,X)K(X,X)1K(X,x)σ^2(x) = \mathcal{K}(x,x) − \mathcal{K}(x,\mathcal{X}) \mathcal{K}(\mathcal{X}, \mathcal{X})^{-1} \mathcal{K}(\mathcal{X}, x),其中 (X,Y)(\mathcal{X,Y}) 分别是输入和目标的训练集。

(2)NTK

当神经网络在均方误差 (MSE) 损失基础上,使用学习率为 ηη 的连续梯度下降优化算法时,在训练点处的函数应当按照 tft(X)=ηJt(X)Jt(X)(ft(X)Y)\partial_t f_t(X ) = −η J_t(\mathcal{X}) J_t(\mathcal{X})^{\top} (f_t(\mathcal{X}) − \mathcal{Y}) 的微分方程形式演化,其中 Jt(X)J_t(\mathcal{X}) 是在 X\mathcal{X} 处计算的输出 ftf_t 的雅可比矩阵,我们用符号 Θt(X,X)=Jt(X)Jt(X)\Theta_t(\mathcal{X,X}) = J_t(\mathcal{X}) J_t(\mathcal{X})^{\top} 表示神经正切核( NTK )。

在无限宽极限条件下,NTK 在整个训练过程中理论上应该保持恒定 (Θt=Θ\Theta_t = \Theta),并且输出的时间演化具有高斯的封闭形式解,其均值为 ft(x)=Θ(x,X)Θ(X,X)1(Iexp[ηΘ(X,X)t])Yf_t(x) = \Theta(x, \mathcal{X}) \Theta( \mathcal{X,X})^{-1} (I − \exp [−η \Theta(\mathcal{X,X})t]) \mathcal{Y}

2 例子

暂略。

Fig1

图 1:与无限网络相比,有限宽度网络集合的训练动态。左图:整个训练过程中训练和测试 MSE 损失演变的均值和方差。右图:经过训练的无限网络的预测与有限宽度网络各自的集成之间的比较。阴影区域和虚线分别表示无限网络和集成预测中不确定性的两个标准偏差。

List1

代码 1:无限 WideResNet 的定义。此代码段同时定义了有限 (init_fn, apply_fn) 和无限 (kernel_fn) 模型。图 2图 3 中使用了该模型。

Fig2

图 2:WideResNet WRN-28-k(其中 k 是加宽因子)NNGP 和 NTK 核(使用 monte_carlo_kernel_fn 计算)的蒙特卡罗 (MC) 估计收敛到它们的分析值(WRN-28-∞,使用kernel_fn ),因为网络通过增加加宽因子(纵轴)变得更宽,并且更多的随机网络在(横轴)上被平均。实验细节。核是在 100×50100×50 批次的 8×88×8 下采样 CIFAR10 (Krizhevsky, 2009) 图像上以 3232 位精度计算的。对于采样效率,对于 NNGP,使用了倒数第二层的输出,对于 NTK,输出层被假定为维度 11(所有 logits 都是独立同分布的, 以给定的输入为条件)。显示的距离是相对 Frobenius 范数的平方,即 KKk,nF2/KF2|\mathcal{K} − \mathcal{K}_{k,n}|^2_F / |\mathcal{K}|_F^2,其中 kk 是加宽因子,nn 是样本数。

Fig3

图 3:具有不同神经网络架构的 CIFAR-10 分类。 NEURAL TANGENTS 简化了架构实验。在这里,我们使用 CIFAR-10 的无限时间 NTK 推断和完整贝叶斯 NNGP 推断,用于完全连接(FC,代码 3)、无池化的卷积网络(CONV,代码 2)和宽残差网络(WRESNET,代码 1)。正如之前工作中常见的那样(Lee 等,2018 年;Novak 等,2019 年),分类任务被视为对零均值目标的 MSE 回归,例如 (0.1,,0.1,0.9,0.1,0.1)(-0.1, \ldots, -0.1, 0.9, - 0.1,\ldots -0.1) 。对于每个训练集大小,通过最小化训练集上的平均负对数边缘似然(NLL,右)来选择系列中的最佳模型。

3 实现:从张量运算转换到核运算

神经网络是基础张量运算(如:稠密或卷积仿射变换、逐点非线性计算、池化或归一化等)的组合。对于大多数在层间没有权重绑定的神经网络而言,核计算也可以组合编写,而且在张量计算和核计算之间存在着直接对应关系(示例参见 第 3.1 节)。 NEURAL TANGENTS 的核心逻辑是一组转换规则,将每个作用在有限宽层上的张量运算,转换为无限宽网络上的核运算。这在 图 4 中针对简单的卷积架构进行了说明。在相关联的列表中,我们将张量运算(第二列)与 NT 和 NNGP 核张量的相应转换(分别为第三列和第四列)进行了比较。当前已经实施转换规则的所有张量运算列表,请参阅 附录 D

设计网络时要考虑的一个微妙之处是:大多数无限宽的结果需要在非线性变换之前进行仿射变换(密集变换或卷积变换)。这是因为无限宽结果通常假设非线性层的激活前近似为高斯分布。权重和偏差的随机性使得无限仿射变换的输出服从高斯分布。幸运的是,在设计神经网络时,将仿射变换放在非线性运算之前是常见的做法,如果不满足此要求,NEURAL TANGENTS 将引发错误。

Fig4

图 4:将卷积神经网络转换为一系列核运算的示例。我们演示了典型神经网络运算对其输入的组合性质如何在 NNGP 和 NT 核上产生相应的组合运算。图中呈现的是一个具有非线性激活 ϕ\phi22 隐层 的 1D 卷积神经网络,对来自数据集 X\mathcal{X}4(1,2,3,4)4(1,2,3,4) 的每个输入 xx,执行 1010 维输出 z2z^2 的回归。为了减少符号,在所有层中都假定单位权重和零偏差方差。顶部:CNN 中的递归输出 (z2)(z^2) 计算(顶部)产生了相应的递归 NNGP 核 (K~2I10)(\tilde{\mathcal{K}}^2 \otimes I_{10}) 计算(NTK 计算类似,但未显示)。底部:每层中张量和相应核操作的显式列表。有关运算的定义参见 表 1。插图和描述改编自 Novak 等(2019)的 图 3

Table1

表 1: 将张量运算转换为 NNGP 和 NTK 核运算的转换规则。这里假定输入张量 X\mathcal{X} 的形状为 X×H×W×C|\mathcal{X}| \times H \times W \times C(数据集大小、高度、宽度、通道数),完整的 NNGP 和 NT 核 K\mathcal{K}T\mathcal{T} 被认为是形状 ( X×H×W)×2|\mathcal{X}| \times H \times W ) \times 2(实际形状 X×2×H×W|\mathcal{X}|^{ \times 2} \times H \times WX×2|\mathcal{X}|^{ \times 2} 也是可能的,这取决于 第 3.2 节 中使用的优化策略)。 符号细节:假定 Tr\text{Tr}GlobalAvgPool\text{GlobalAvgPool} 运算作用于所有空间轴(本例中大小为 HHWW),产生 X×2|\mathcal{X}|^{ \times 2} 的核。类似地,假设 AvgPool\text{AvgPool} 运算也作用于所有空间轴,将指定的步幅 ss、池化窗口大小 pp 和填充策略 pp 应用于 K\mathcal{K}T\mathcal{T} 中的相应轴对(充当具有 2D 版本复制参数的 4D 池化)。T\mathcal{T} 和 ̇TT˙\dot{T\mathcal{T}} 的定义与 Lee 等 (2019)相同, T(Σ)=E[ϕ(u)ϕ(u)]T(\Sigma) = \mathbb{E}[\phi(u)\phi(u)^{\top}] ,T˙(Σ)=E[ϕ(u)ϕ(u)]\dot{\mathcal{T}}(\Sigma) = \mathbb{E}[\phi'(u)\phi'(u)^{\top}], uNN(0,Σ)u \sim N\mathcal{N}(0, \Sigma)。这些表达式可以以封闭形式对许多非线性进行计算,并保持核的形状。A\mathcal{A} 运算的定义类似于 Novak 等 (2019);肖等(2018), [A(Σ)]h,hw,w(x,x)=dh,dw[Σ]h+dh,h+dhw+dw,w+dw(x,x)/q2[\mathcal{A}(\Sigma)]^{w,w'}_{h,h'} (x, x') = \sum_{dh,dw} [\Sigma]^{w+dw,w'+dw}_{h+dh,h'+dh} (x,x')/q^2,其中求和是在具有 qq 个像素的卷积滤波器感受野上执行的(我们在此表达式中假设单位步幅和圆形填充)。 [Σ]n=[Σ,,Σ][\Sigma]∗ n = [\Sigma, \ldots , \Sigma](n 折复制)。有关将转换规则应用于特定模型的示例,请参见图 4,以及示例转换规则的 第 3.1 节

3.1 张量-to-核运算转换的尝试

为了获得转换规则背后的一些直觉,我们考虑非线性后跟密集层的情况。令 z=z(X,θ)Rd×nz = z(\mathcal{X},θ) \in \mathbb{R}^{d \times n} 是神经网络某个隐藏层中节点处从 d 个不同输入产生的激活前。假设 zz 具有 NNGP 核和 NTK:

Kz=Eθ[ziziT],Θz=Eθ[ziθ(ziθ)](1)\mathcal{K}_z = \mathbb{E}_θ [z_iz^T_i] , \qquad \Theta_z = \mathbb{E}_θ \left[ \frac{\partial z_i }{\partial θ} \left(\frac{\partial z_i }{\partial θ} \right)^{\top} \right] \tag{1}

其中 ziRdz_i \in \mathbb{R}^{d} 是第 ii 个神经元,θθ 是网络中直到 zz 为止的参数。这里 dd 是网络输入 X\mathcal{X} 的基数,nnzz 节点中的神经元数量。我们假设 zz 是零均值多元高斯分布。我们希望通过分别计算 y=ϕ(z)y = \phi(z)h=Dense(σω,σb)(y)h = \text{Dense}(σ_ω, σ_b) (y) 的核来计算对应于 h=Dense(σω,σb)(ϕ(z))h = \text{Dense}(σ_ω, σ_b) (\phi(z)) 的核。

h=Dense(σω,σb)(y)(1/n)σωWy+σbβ(2)h=\operatorname{\text{Dense}}\left(\sigma_\omega, \sigma_b\right)(y) \equiv(1 / \sqrt{n}) \sigma_\omega W y+\sigma_b \beta \tag{2}

并且变量 WijW_{ij}βi\beta_i 是独立同分布的,服从高斯 N(0,1)\mathcal{N}(0, 1)。我们将计算表示为 ϕ\phi^*Dense(σω,σb)\text{Dense}(σ_ω, σ_b)^* 的核运算,由张量运算 ϕ\phiDense(σω,σb)\text{Dense}(σ_ω, σ_b) 引起。最后,我们将计算与组合相关的核运算 (Dense(σω,σb)ϕ)=Dense(σω,σb)ϕ(\text{Dense}(σ_ω, σ_b) \circ \phi)^* = \text{Dense}(σ_ω, σ_b)^* \circ \phi^*

首先我们计算 yy 的 NNGP 和 NT 核。要计算 Ky\mathcal{K}_y,请注意根据其定义,

Ky=Kϕ(z)=Eθ[ϕ(z)iϕ(z)iT]=Eθ[ϕ(zi)ϕ(zi)T]=T(Kz)(3)\mathcal{K}_y=\mathcal{K}_{\phi(z)}=\mathbb{E}_\theta\left[\phi(z)_i \phi(z)_i^T\right]=\mathbb{E}_\theta\left[\phi\left(z_i\right) \phi\left(z_i\right)^T\right]=\mathcal{T}\left(\mathcal{K}_z\right) \tag{3}

由于 ϕ\phi 没有引入任何新变量 Θy\Theta_y 可以计算为,

Θy=Eθ[ϕ(zi)θ(ϕ(zi)θ)T]=Eθ[diag(ϕ˙(zi))ziθ(ziθ)Tdiag(ϕ˙(zi))]=T˙(Kz)Θz.\Theta_y=\mathbb{E}_\theta\left[\frac{\partial \phi\left(z_i\right)}{\partial \theta}\left(\frac{\partial \phi\left(z_i\right)}{\partial \theta}\right)^T\right]=\mathbb{E}_\theta\left[\operatorname{diag}\left(\dot{\phi}\left(z_i\right)\right) \frac{\partial z_i}{\partial \theta}\left(\frac{\partial z_i}{\partial \theta}\right)^T \operatorname{diag}\left(\dot{\phi}\left(z_i\right)\right)\right]=\dot{\mathcal{T}}\left(\mathcal{K}_z\right) \odot \Theta_z .

综合这些等式意味着,

(Ky,Θy)=ϕ(Kz,Θz)(T(Kz),T˙(Kz)Θz)\left(\mathcal{K}_y, \Theta_y\right)=\phi^*\left(\mathcal{K}_z, \Theta_z\right) \equiv\left(\mathcal{T}\left(\mathcal{K}_z\right), \dot{\mathcal{T}}\left(\mathcal{K}_z\right) \odot \Theta_z\right)

将是逐点非线性的平移规则。请注意,式 (4) 只对一小部分激活函数 ϕ\phi 有解析表达式。

接下来我们考虑密集操作的情况。使用权重、偏差和 hh 之间的独立性,可以得出,

Kh=EW,β,θ[hihiT]=σω2Eθ[yiyiT]+σb2=σω2Ky+σb2(5)\mathcal{K}_h = \mathbb{E}_{W,β,θ}[h_ih_i^T] = σ^2_ω \mathbb{E}_θ [y_i y^T_i] + σ^2_b = σ^2_ω \mathcal{K}_y + σ^2_b \tag{5}

最后,hh 的 NTK 可以计算为两项的总和:

Θh=EW,β,θ[hi(W,β)(hi(W,β))]+EW,β,θ[hiθ(hiθ)]=σω2Ky+σb2+σω2Θy(6)\Theta_h = \mathbb{E}_{W,β,θ} \left[ \frac{\partial h_i }{\partial (W, β)} \left( \frac{\partial h_i}{\partial (W, β)} \right)^{\top} \right] + \mathbb{E}_{W,β,θ} \left[\frac{ \partial h_i}{\partial θ} \left(\frac{\partial h_i }{ \partial θ} \right)^{\top} \right] = σ^2_ω \mathcal{K}_y + σ^2_b + σ^2_ω \Theta_y \tag{6}

这给出了关于 Ky\mathcal{K}_yΘy\Theta_y 的密集层的转换规则,

(Kh,Θh)=Dense(σω,σb)(Ky,Θy)(σω2Ky+σb2,σω2Ky+σb2+σω2Θy)(7)(\mathcal{K}_h, \Theta_h) = \text{Dense}(σ_ω, σ_b)^* (\mathcal{K}_y, \Theta_y) \equiv (σ^2_ω \mathcal{K}_y + σ^2_b, σ^2_ω \mathcal{K}_y + σ^2_b + σ^2_ω \Theta_y ) \tag{7}

3.2 性能

我们的库在不牺牲灵活性的情况下执行了许多自动性能优化

利用块对角线协方差结构。高斯过程的一个常见计算挑战是训练集协方差矩阵的求逆。对于具有 CC 个类和训练集 X\mathcal{X} 的分类任务,NNGP 和 NTK 协方差的形状为 XC×XC|\mathcal{X}| C \times |\mathcal{X}| C。对于 CIFAR-10,这将是 500,000×500,000500, 000 \times 500, 000。但是,如果使用全连接的读出层(这是分类架构中极为常见的设计),则 CC 个 logits 是独立同分布的。以输入 xx 为条件。这导致输出呈正态分布,具有形式为 ΣIC\Sigma \otimes I_C 的块对角协方差矩阵,其中 4\Sigma4 的形状为 X×X|\mathcal{X}| \times |\mathcal{X}|,且 ICI_CC×CC \times C 单位矩阵。这将许多常见情况下的计算复杂度和存储减少了一个数量级,这使得封闭形式的精确推理在这些情况下变得可行。

仅自动跟踪中间协方差元素的最小必要子集。对于大多数架构,尤其是卷积架构,主要的计算负担在于构建协方差矩阵(而不是求逆矩阵)。专门针对深度为 ll 的卷积网络,构造 X×X|\mathcal{X}| \times |\mathcal{X}| 输出协方差矩阵 Σ\Sigma 涉及计算 ll 个中间层协方差矩阵 Σl\Sigma_l,大小为 Xd×Xd|\mathcal{X}| d \times |\mathcal{X}| d(请参阅 代码 1 了解需要此计算的模型),其中 dd 是中间层输出中的像素总数(例如,对于具有 SAME 填充的 CIFAR-10,d=1024d = 1024)。然而,正如 Xiao 等 (2018); Novak 等 (2019); Garriga-Alonso 等(2019) 指出,如果网络中没有使用池化,则输出协方差 Σ\Sigma 可以仅使用 Σl\Sigma_ldX×Xd |\mathcal{X}| \times |\mathcal{X}| 块的堆叠来计算,将时间和内存成本从 O(X2d2)\mathcal{O}(|\mathcal{X}|^2 d^2) 降低到每层 O(X2d)\mathcal{O}(|\mathcal{X}|^2 d)(参见图 4代码 2,了解承认这种优化的模型)。最后,如果网络没有卷积层,成本会进一步降低到 O(X2)\mathcal{O}(|\mathcal{X}|^2)(示例参见代码 3)。这些选择由 NEURAL TANGENTS 自动执行,以实现高效计算和最小内存占用。

将协方差计算表示为具有最佳布局的 2D 卷积。卷积模型高性能的一个关键见解是,卷积层的协方差传播运算 A\mathcal{A} 可以用二维卷积来表示,当其在 Xd×Xd|\mathcal{X}| d \times |\mathcal{X}| d 的全协方差矩阵 Σ\Sigmadd 对角的 X×X|\mathcal{X}| \times |\mathcal{X}|-块矩阵上运算时。这允许利用现代硬件加速器,其中许多将 2D 卷积作为其主要机器学习应用。

同时进行 NNGP 和 NT 核计算。由于 NTK 计算需要 NNGP 协方差作为中间计算,因此 NNGP 协方差与 NTK 一起计算,无需额外成本。这对于希望研究这两个无限宽神经网络极限之间异同的研究人员来说特别方便。

跨多个设备的自动批处理和并行性。在大多数情况下,随着数据集或模型变大,不可能一次执行整个核计算。此外,在许多情况下,希望跨设备(CPU、GPU 或 TPU)并行化核计算。 NEURAL TANGENTS 提供了一种使用如下所示的单个批处理装饰器来执行这两项常见任务的简单方法:

1
2
batched_kernel_fn = nt.batch(kernel_fn, batch_size) 
batched_kernel_fn(x, x) == kernel_fn(x, x) # True!

此代码适用于解析核或经验核。默认情况下,它会自动在所有可用设备上共享计算。在计算 图 52121 层卷积网络的理论 NTK 时,我们将性能绘制为批量大小和加速器数量的函数,观察到近乎完美的加速器数量扩展。

Fig5

图 5:性能随批量大小(左)和 GPU 数量(右)的变化而变化。显示在具有全局平均池化的 21 层 ReLU 网络中计算分析 NNGP 和 NTK 协方差矩阵(使用 kernel_fn )所需的每个条目的时间。左图:在计算块中的协方差矩阵时增加批量大小可以显著提高性能,直到单个 GPU 中的所有核都饱和时达到某个阈值。预计更简单的模型可以更好地缩放批量大小。右图:每个样本的时间与 GPU 的数量呈线性关系,展示了近乎完美的硬件利用率。

运算融合。 JAX 和 XLA 允许对整个核计算和/或推断进行端到端编译。这使 XLA 编译器能够将低级操作融合到自定义模型特定的加速器核中,并消除从逐个操作分派到加速器的开销。以类似的方式,我们允许协方差张量逐层更改其维度顺序,并跟踪顺序并将其解析为引擎盖下的附加元数据。这通过根据输入元数据调整每个层执行的计算来消除冗余转置。

4 结论

我们相信 NEURAL TANGENTS 将使研究人员能够快速轻松地探索无限宽的网络。通过使这个先前具有挑战性的模型系列大众化,我们希望研究人员在面对新的问题领域(尤其是在数据有限的情况下)时,除了有限的对应物之外,还将开始使用无限的神经网络。此外,我们很高兴看到无限网络作为理论工具的新颖用途,可以深入了解和澄清深度学习中的许多难题。展望未来,我们正在探索 NEURAL TANGENTS 的重要补充。我们希望在未来添加更多层 (§D),从而实现更大范围的无限网络拓扑。此外,我们还希望实现进一步的性能改进,以允许试验更大的模型和数据集。我们邀请社区加入我们的努力,为库贡献新层(§B.7),或将其用于研究并提供反馈!

Fig6

图 6:使用 nt.taylor_expand 训练神经网络及其各种近似值。展示了一个宽度为 51251255 层 Erf 神经网络,使用带动量的 SGD 在 MNIST 上训练,以及关于初始参数的常数(0 阶)、线性(1 阶)和二次(2 阶)泰勒展开。随着训练的进行(从左到右),低阶扩展比高阶扩展更快地偏离原始函数。

List2

代码 2: 图3 中使用的全卷积模型(ConvOnly)定义

List3

代码 3: 图 3 中使用的全连接 (FC) 模型定义。

Fig7

图 7:预测负对数似然和条件数。顶部:在 CIFAR-10(2000 个点的测试集)的无限训练时间内,测试 NNGP 后验的负对数似然和 NTK 的高斯预测分布。全连接(FC,代码 3)和无池化的卷积网络(CONV,代码 2)模型是根据图 3 中的训练边际负对数似然选择的。底部:对应于 NTK/NNGP 的协方差矩阵的条件数以及测试集上的相应预测协方差。由于池化层(Anonymous,2020)导致的 Wide Residual Network 核的病态条件可能是评估该核的预测 NLL 时出现数值问题的原因。