Skip to main content

将 SymPy 表达式转换为可训练的 JAX 表达式。

项目描述

sympy2jax

将 SymPy 表达式转换为可训练的 JAX 表达式。输出将是一个Equinox模块,其中所有 SymPy 浮点数(整数、有理数,...)作为叶子。SymPy 符号将作为输入。

通过梯度下降优化您的符号表达式!

安装

pip install sympy2jax

要求:
Python 3.7+
JAX 0.3.4+
Equinox 0.5.3+
SymPy 1.7.1+。

例子

import jax
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jax.numpy.zeros(3)
out = mod(x_sym=x)  # PyTree of results.
params = jax.tree_leaves(mod)  # 1.0 and 2.0 are parameters.
                               # (Which may be trained in the usual way for Equinox.)

文档

sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)

在哪里:

  • expressions是 SymPy 表达式的 PyTree。
  • extra_funcs是从 SymPy 函数到 JAX 操作的可选字典,用于扩展内置翻译规则。
  • make_array是整数/浮点数/有理数是否应存储为 Python 整数/等,或 JAX 数组。

可以使用符号-值的键值对调用实例,如上例所示。

实例具有.sympy()将模块转换回 SymPy 表达式的 PyTree 的方法。

(这实际上是整个文档,非常简单。)

最后

另请参阅:JAX 生态系统中的其他工具

神经网络:Equinox

数值微分方程求解器:Diffrax

PyTrees 和 JAX 数组的 shape/dtype 的类型注释和运行时检查:jaxtyping

免责声明

这不是 Google 的官方产品。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

sympy2jax-0.0.4.tar.gz (8.8 kB 查看哈希

已上传 source

内置分布

sympy2jax-0.0.4-py3-none-any.whl (9.3 kB 查看哈希

已上传 py3