GPyTorch 中可扩展高斯过程的格核
项目描述
单纯型GP
此存储库托管SKIing on Simlices 的代码: Sanyam Kapoor、Marc Finzi、 Ke Alexander Wang、 Andrew Gordon Wilson的可伸缩高斯过程(Simplex-GPs) 的 Permutohedral Lattice 上的内核插值。
理念
快速矩阵向量乘法 (MVM) 是现代可扩展高斯过程的基石。通过建立在结构化内核插值(SKI)提出的近似值的基础上 ,并利用快速高维图像滤波方面的进步,Simplex-GPs 通过使用稀疏置换晶格而不是矩形网格平铺空间来近似内核矩阵的计算.
SKI 中的核运算所隐含的矩阵向量积现在通过上面可视化的三个阶段来近似 ——splat(投影到 permutohedral 晶格上)、 blur(将模糊运算作为矩阵向量积应用)和 slice(重新投影回原来的空间)。
这减轻了与 SKI 操作相关的维度灾难,允许它们扩展到超过约 5 个维度,并在运行时和内存成本方面提供竞争优势,而下游性能的损失很小。有关完整的详细信息,请参阅我们的手稿。
用法
lattice 内核被打包为 GPyTorch 模块,可以RBFKernel
用作MaternKernel
. 对应的替换模块是RBFLattice
和MaternLattice
。
RBFLattice
内核很容易通过更改一行代码来使用:
import gpytorch as gp
from gpytorch_lattice_kernel import RBFLattice
class SimplexGPModel(gp.models.ExactGP):
def __init__(self, train_x, train_y):
likelihood = gp.likelihoods.GaussianLikelihood()
super().__init__(train_x, train_y, likelihood)
self.mean_module = gp.means.ConstantMean()
self.covar_module = gp.kernels.ScaleKernel(
- gp.kernels.RBFKernel(ard_num_dims=train_x.size(-1))
+ RBFLattice(ard_num_dims=train_x.size(-1), order=1)
)
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gp.distributions.MultivariateNormal(mean_x, covar_x)
GPyTorch回归教程 提供了一个关于玩具数据的更简单示例,该内核可用作替代品。
安装
要在代码中使用内核,请将软件包安装为:
pip install gpytorch-lattice-kernel
注意:内核是使用CMake从源代码延迟编译的。如果编译失败,您可能需要安装更新的版本。此外,ninja
编译需要。一种安装方法是:
conda install -c conda-forge cmake ninja
本地设置
对于本地开发设置,创建conda
环境
$ conda env create -f environment.yml
如果还没有,请记住将项目的根目录添加到 PYTHONPATH。
$ export PYTHONPATH="$(pwd):${PYTHONPATH}"
测试
为了验证代码是否按预期工作,提供了一个简单的测试文件 ,用于测试 Simplex-GPs 和 Exact-GPs 实现的训练边际可能性。运行为:
python tests/train_snelson.py
使用Snelson 1-D 玩具数据集。snelson.csv中有一份副本。
结果
建议的内核可以像往常一样与 GPyTorch 一起使用。重现结果的示例脚本是,
python experiments/train_simplexgp.py --dataset=elevators --data-dir=<path/to/uci/data/mat/files>
我们使用Fire来处理 CLI 参数。因此,函数的所有参数main
都是 CLI 的有效参数。
论文中的所有数字都可以通过笔记本复制。
注意:UCI 数据集mat
文件可在此处获得。
执照
阿帕奇 2.0
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。
源分布
内置分布
gpytorch_lattice_kernel -0.0.dev1-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | fe7eadcfa48aefecb0d310368c94c98a25776718491583d91268e6bbaf2fd977 |
|
MD5 | f96aad3055257d4f7ddaa32e26cf7282 |
|
布莱克2-256 | e59d8db1ee9db20b94a61e76de4e90ef2e3c8ba42fed16586f7cb0c0768c4581 |