高效的高斯神经过程回归
【摘 要】条件神经过程 (CNP)是一个有吸引力的元学习模型系列,它可以产生经过良好校准的预测,能够在测试时进行快速推断,并且可以通过简单的最大似然程序进行训练。 CNP 的局限性在于它们无法对输出中的依赖关系进行建模。这极大地影响了预测性能,并且无法抽取相干的函数样本,从而限制了 CNP 在下游应用和决策制定中的适用性。神经过程 (NPs) 试图通过使用隐变量来缓解这个问题,并靠此来建模输出的依赖性,但带来了近似推断的困难。最近的一种替代方法是 FullConvGNP,它可以对预测中的依赖性进行建模,同时仍然可以通过精确的最大似然法进行训练。不幸的是,FullConvGNP 依赖于昂贵的二维卷积,这使其仅适用于一维数据。在本文工作中,我们提出了一种新方法来模拟输出依赖性,它适用于最大似然训练,但可以扩展到二维和三维数据。所提出的模型在合成实验中表现出了良好性能。
【原 文】 Markou, S. 等 (2021) ‘Efficient Gaussian Neural Processes for Regression’. arXiv. Available at: http://arxiv.org/abs/2108.09676 (Accessed: 23 February 2023).
1 引言
条件神经过程 (CNP; Garnelo 等, 2018a) 是最近提出的一类元学习模型,它有望将神经网络的建模灵活性、稳健性和快速推理与高斯过程的校准不确定性(GPs; Rasmussen)结合起来, 2003). CNP 使用简单的最大似然程序进行训练,并进行复杂度与数据点数量成线性关系的预测。最近的工作通过结合注意力机制(Kim 等,2019 年)或考虑预测问题中的对称性(Gordon 等,2020 年;Kawano 等,2021 年)扩展了 CNP,在各种任务上取得了令人印象深刻的表现。
尽管有这些有利的品质,但 CNP 仅限于对不存在输出依赖性建模的预测,将不同的输入位置视为独立的( 图(1)
)。在本文中,我们将这种预测称为平均场。无法对依赖关系进行建模会影响性能并使 CNP 无法生成连贯的函数样本,从而限制了它们在下游应用程序中的适用性。例如,在降水建模中,我们可能希望评估某些地区每天的降雨量在持续的一段时间内保持在某个指定阈值之上的事件概率,这可能有助于评估发生洪水的可能性。对每个位置独立建模的平均场预测会给此类事件分配不合理的低概率。然而,如果我们能够从预测中抽取连贯的样本,那么就可以更合理地估计这些事件的概率和许多其他有用的量。
图 1. 与 ConvCNP(顶部)仅产生边缘预测不同,ConvGNP(底部)提供具有相关性的预测,并且可用于抽取相干函数样本。
为了解决 CNP 无法在预测中建立依赖关系模型的问题,后续工作(Garnelo 等,2018b)引入了隐变量,引入了近似推理带来的困难(Le 等,2018 年;Foong 等,2020 年) .最近 Bruinsma 等。 (2021) 引入了 CNP 的一种变体,称为高斯神经过程,以下称为 FullConvGNP,它直接参数化输出的预测协方差。然而,对于 D 维数据,FullConvGNP 的架构涉及 2D 维卷积,这是昂贵的,并且对于 ,大多数深度学习库的支持不佳。这项工作引入了一种直接参数化输出依赖项的替代方法,它绕过了 FullConvGNP 的昂贵卷积,并且可以应用于更高维的数据。
2 条件神经过程和高斯神经过程
根据 Foong 等 (2020)的工作,我们从 预测映射 角度介绍条件神经过程(CNP)。预测映射 是一个将: (1) 上下文集合 的函数,其中 是输入, 是输出; (2) 一组目标输入 映射到到相应目标输出 的函数:
其中 是一个向量,它对 上的分布进行参数化。对于固定上下文集 ,使用 Kolmogorov 的扩展定理 (Oksendal, 2013),对于所有 的有限维分布 (f.d.d.s) 的集合定义了一个随机过程,如果这些 f.d.d.s 在 (i) 的任何条目的排列和 (ii) 的任何条目的边缘化下一致。预测映射包括但不限于贝叶斯后验。这种映射的一个熟悉的例子是贝叶斯高斯过程后验
其中 和 由通常的高斯过程后验表达式给出 (Rasmussen, 2003)。另一个预测映射是 CNP (Garnelo 等, 2018a):
其中每个 p(\mathbf{y}_t,m|rm) 都是独立的高斯分布,rm = r(\mathbf{x}_c, \mathbf{y}_c, \mathbf{x}_t,m) 由 DeepSet3 参数化(Zaheer 等,2017)。 CNP 是排列和边缘化一致的,因此对应于有效的随机过程。然而,CNP 通常不遵守产品规则——参见附录 A 和 Foong 等 (2020)。尽管如此,CNP 及其变体 (Gordon 等, 2020) 已被证明可以在各种任务中提供有竞争力的性能和稳健的预测,并且是一类很有前途的元学习模型。
等式中预测的一个核心问题。 (3) 是平均场:eq. (3) 不对 \mathbf{y}_t,m 和 \mathbf{y}_t,m’ 之间的相关性建模,因为 m 6= m’。平均场预测严重损害了预测对数似然。此外,不能使用平均场预测来绘制相干函数样本。为了解决这些问题,我们考虑参数化相关的多元高斯
π (\mathbf{y}_t; \mathbf{x}_c, \mathbf{y}_c, \mathbf{x}_t) = N (\mathbf{y}_t; m, K)
其中,我们使用神经网络来参数化均值 m = m(\mathbf{x}_c, \mathbf{y}_c, \mathbf{x}_t) 和协方差 K = K(\mathbf{x}_c, \mathbf{y}_c, \mathbf{x}_t),而不是贝叶斯高斯过程后验表达式。我们将此类模型称为高斯神经过程 (GNP)。第一个这样的模型,FullConvGNP,由 Bruinsma 等介绍。 (2021) 取得可喜的成果。不幸的是,FullConvGNP 依赖于二维卷积来参数化 K,这很难扩展到更高的维度。为了克服这个困难,我们建议通过以下方式对 m 和 K 进行参数化
mi = f (\mathbf{x}_t,i, r), (5) Kij = k(g(\mathbf{x}_t,i, r), g(\mathbf{x}_t,j, r)) (6)
where r = r(\mathbf{x}_c, \mathbf{y}_c), f and g are neural networks with outputs in R and RDg respectively, and k is an appropriately chosen positive-definite function.
请注意,由于 k 对后验协方差建模,因此它不可能是平稳的。等式 (5) 和 (6) 定义了一类 GNP,与 FullConvGNP 不同,它不需要昂贵的卷积。 GNP 可以很容易地通过对数似然进行训练
θ∗ = arg maxθ log π (\mathbf{y}_t; \mathbf{x}_c, \mathbf{y}_c, \mathbf{x}_t)
也用于 (Garnelo 等, 2018a),其中 θ 收集神经网络 f、g 和 r 的所有参数。在这项工作中,我们考虑了两种参数化 K 的方法。第一种方法是线性协方差 Kij = g(\mathbf{x}_t,i, r)>g(\mathbf{x}_t,j , r) (8)
可以将其解释为具有 Dg 基函数和权重单位高斯分布的参数线性模型。该模型元学习 Dg 上下文相关的基函数,它试图在给定上下文的情况下最好地近似目标的真实分布。根据 Mercer 定理 (Rasmussen, 2003),在满足正则性条件下,每个正定函数 k 都可以分解为 k(z, z′) = ∑∞ d=0 φd(z)φd(z′) (9)
其中 (φd)∞ d=1 是一组正交基函数。因此,我们期望 eq。 (8) 随着 Dg 变大,能够恢复任意(足够规则的)GP 预测。此外,线性协方差具有吸引人的特征,即从中采样与查询位置的数量成线性比例关系。缺点是有限数量的基函数可能会限制其表达能力。回避这个问题的另一种方法是使用 kvv 协方差对 K 进行参数化: Kij = k(g(\mathbf{x}_t,i, r), g(\mathbf{x}_t,j, r))v(\mathbf{x}_t,i, r)v(\mathbf{x}_t ,j,r) (10)
其中 k 是指数二次方 (EQ) 协方差,v 是一个输出为 R 的神经网络。v 因子调节协方差的大小,否则协方差将无法在上下文点附近收缩。与线性不同,kvv 不受有限数量的基函数限制。 kvv 的一个缺点是从中抽取样本的成本在查询位置的数量上成立方比例,这可能会带来重要的实际限制。
linear 和 kvv 都为根据手头的任务选择 f 、 g 和 r 留出了空间,从而产生了 GNP 系列的不同模型的集合。例如,我们可以选择这些作为前馈 DeepSets,从而产生高斯神经过程 (GNP); attentive DeepSets,引起注意力高斯神经过程(AGNPs);或卷积架构,产生了卷积高斯神经过程(ConvGNPs)。在这项工作中,我们探索了这三种替代方案,提出将 ConvGNP 作为 FullConvGNP 的可扩展替代方案。这种方法可以扩展到多个输出,我们将在未来的工作中解决这个问题。
3 实验
我们将所提出的模型应用于从具有各种协方差函数和已知超参数的高斯过程生成的合成数据集。我们将这些数据集子采样到上下文和目标集中,并通过对数似然(方程式(7))进行训练。
我们还训练了 Foong 等讨论的 ANP 和 ConvNP 模型。 (2020)。这些隐变量模型在 r 上放置一个分布 q 并依赖 q 来建模输出依赖性。跟随 Foong 等我们通过目标 θ∗ = arg maxθ log [ Er∼q® [p (\mathbf{y}_t; \mathbf{x}_c, \mathbf{y}_c, \mathbf{x}_t, r) ] ] 的有偏蒙特卡罗估计来训练 ANP 和 ConvNP。 (11)
图(2)
比较了模型的预测对数似然,根据分布数据进行评估,我们从中观察到以下趋势。
依赖关系提高性能:我们预计对输出依赖关系进行建模将使模型能够实现更好的对数似然。事实上,对于一个固定的架构,我们看到相关的 GNPs ( , , , , , ) 通常优于它们的平均场对应物 ( , , )。这一结果令人鼓舞,表明 GNP 可以在实践中学习有意义的依赖关系,在某些情况下可以恢复预言机的性能。
与 FullConvGNP 的比较:相关的 ConvGNPs ( , ) 通常与 FullConvGNP ( ) 具有竞争力。 kvv ConvGNP ( ) 是此处检查的模型中唯一在所有任务中与 FullConvGNP 竞争的模型。然而,与后者不同的是,前者可扩展到 D = 2、3 维。
与 ANP 和 ConvNP 的比较:Correlated GNPs 通常优于隐变量 ANP ( ) 和 ConvNP ( ) 模型,这可以解释为 GNPs 具有高斯预测而 ANP 和 ConvNP 没有,并且所有任务都是高斯的.尽管尝试了不同的架构,甚至与 AGNP ( , ) 和 ConvGNP ( , ) 相比,ANP 和 ConvNP 允许更多的参数,但我们发现很难使隐变量模型与 GNP 竞争。我们通常发现 GNP 家族比这些隐变量模型更容易训练。
Kvv 优于线性模型:我们通常观察到 kvv 模型 ( , , ) 的表现与线性模型 ( , , ) 一样好,有时甚至更好。为了测试线性模型是否受基函数数量 Dg 的限制,我们尝试了各种设置 Dg \in {16, 128, 512, 2048}。我们没有观察到大 Dg 的性能改进,这表明模型不受此因素的限制。这是令人惊讶的,因为当 Dg → ∞ 并假设 f、g 和 r 足够灵活时,根据 Mercer 定理,线性模型应该能够恢复任何(足够规则的)GP 后验。从初步调查来看,我们保留了线性模型可能更难优化并因此难以与 kvv 竞争的可能性。我们希望在未来对我们的训练协议进行更仔细的研究,以确定训练方法是否可以解决这种性能差距,或者 kvv 模型是否从根本上比线性模型更强大。
图(3)
显示了从 GNP 模型中提取的样本,从中我们定性地观察到,与 FullConvGNP 一样,ConvGNP 产生了高质量的函数样本。这些样本与观察到的数据一致,同时保持不确定性并捕获基础过程的行为。 ConvGNP 是唯一产生高质量后验样本的条件模型(除了 FullConvGNP)。 图(4)
显示了模型的协方差图。观察到,与 FullConvGNP 一样,ConvGNP 能够恢复复杂的协方差结构。
4 结论与进一步工作
这项工作介绍了一种替代方法,用于对 CNP 中的相关高斯预测进行参数化。这种方法可以与现有的 CNP 架构相结合,例如前馈 (GNP)、注意力网络 (AGNP) 或卷积网络 (ConvGNP)。与 Bruinsma 等现有的 FullConvGNP 相比,由此产生的模型在计算上更便宜并且更容易扩展到更高的维度。 (2021 年),同时仍然可以通过精确的最大似然法进行训练。 ConvGNP 优于我们在这项工作中考虑的其他条件和隐变量模型,但 FullConvGNP 除外。我们发现输出中的建模依赖性提高了平均场模型的预测对数似然性。它还允许我们绘制连贯的函数样本,这意味着 GNP 可以与更精细的下游估计器链接。与预测性为非分析性的 ANP 和 ConvNP 模型不同,我们希望在 GNP 中评估主动学习获取功能等更加容易和易于处理,这是我们希望在未来工作中探索的一个用例。我们还注意到,尽管 ConvGNP 比 FullConvGNP 表现出更好的缩放比例,但在应用于更高维度时它们仍然需要 2 维或 3 维卷积,这也非常昂贵。我们希望探索降低此成本的方法,以及其他类型的等方差,例如旋转和反射等方差(Kawano 等,2021 年;Holderrieth 等,2020 年)如何在计算上扩展到更高维度便宜的方式。为此,一种类似于 Satorras 等的工作的方法。 (2021) 在等变 GNN 的背景下是一个有前途的方向。我们相信,针对高维数据的廉价且可扩展的条件神经过程在广泛的应用中可能具有很高的价值,包括天气和环境建模、模拟、图形和视觉。