JAX 数组和 PyTrees 的形状和 dtype 的类型注释和运行时检查。
项目描述
打字机
类型注释和运行时检查:
例如:
from jaxtyping import Array, Float, PyTree
# Accepts floating-point 2D arrays with matching dimensions
def matrix_multiply(x: Float[Array, "dim1 dim2"],
y: Float[Array, "dim2 dim3"]
) -> Float[Array, "dim1 dim3"]:
...
def accepts_pytree_of_ints(x: PyTree[int]):
...
def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]):
...
安装
pip install jaxtyping
需要 JAX 0.3.4+。
还要安装你最喜欢的运行时类型检查包。最流行的两个是typeguard(它详尽地检查每个参数)和Beartype(它检查随机的参数片段)。
文档
最后
另请参阅:JAX 生态系统中的其他工具
神经网络:Equinox。
数值微分方程求解器:Diffrax。
SymPy<->JAX 转换;通过梯度下降训练符号表达式:sympy2jax。
致谢
形状注释 + 运行时类型检查的灵感来自于TorchTyping。
简洁的语法部分受到etils.array_types的启发。
免责声明
这不是 Google 的官方产品。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。
源分布
jaxtyping-0.2.7.tar.gz
(14.0 kB
查看哈希)
内置分布
jaxtyping-0.2.7-py3-none-any.whl
(17.9 kB
查看哈希)