Skip to main content

JAX 数组和 PyTrees 的形状和 dtype 的类型注释和运行时检查。

项目描述

打字机

类型注释和运行时检查

  1. JAX数组的形状和数据类型;
  2. PyTrees

例如:

from jaxtyping import Array, Float, PyTree

# Accepts floating-point 2D arrays with matching dimensions
def matrix_multiply(x: Float[Array, "dim1 dim2"],
                    y: Float[Array, "dim2 dim3"]
                  ) -> Float[Array, "dim1 dim3"]:
    ...

def accepts_pytree_of_ints(x: PyTree[int]):
    ...

def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
    ...

安装

pip install jaxtyping

需要 JAX 0.3.4+。

还要安装你最喜欢的运行时类型检查包。最流行的两个是typeguard(它详尽地检查每个参数)和Beartype(它检查随机的参数片段)。

文档

完整的 API 参考

FAQ(静态类型检查、flake8 等)

最后

另请参阅:JAX 生态系统中的其他工具

神经网络:Equinox

数值微分方程求解器:Diffrax

SymPy<->JAX 转换;通过梯度下降训练符号表达式:sympy2jax

致谢

形状注释 + 运行时类型检查的灵感来自于TorchTyping

简洁的语法部分受到etils.array_types的启发。

免责声明

这不是 Google 的官方产品。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

jaxtyping-0.2.7.ta​​r.gz (14.0 kB 查看哈希

已上传 source

内置分布

jaxtyping-0.2.7-py3-none-any.whl (17.9 kB 查看哈希

已上传 py3