将 SymPy 表达式转换为可训练的 JAX 表达式。
项目描述
sympy2jax
将 SymPy 表达式转换为可训练的 JAX 表达式。输出将是一个Equinox模块,其中所有 SymPy 浮点数(整数、有理数,...)作为叶子。SymPy 符号将作为输入。
通过梯度下降优化您的符号表达式!
安装
pip install sympy2jax
要求:
Python 3.7+
JAX 0.3.4+
Equinox 0.5.3+
SymPy 1.7.1+。
例子
import jax
import sympy
import sympy2jax
x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx]) # PyTree of input expressions
x = jax.numpy.zeros(3)
out = mod(x_sym=x) # PyTree of results.
params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
# (Which may be trained in the usual way for Equinox.)
文档
sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
在哪里:
expressions是 SymPy 表达式的 PyTree。extra_funcs是从 SymPy 函数到 JAX 操作的可选字典,用于扩展内置翻译规则。make_array是整数/浮点数/有理数是否应存储为 Python 整数/等,或 JAX 数组。
可以使用符号-值的键值对调用实例,如上例所示。
实例具有.sympy()将模块转换回 SymPy 表达式的 PyTree 的方法。
(这实际上是整个文档,非常简单。)
最后
另请参阅:JAX 生态系统中的其他工具
神经网络:Equinox。
数值微分方程求解器:Diffrax。
PyTrees 和 JAX 数组的 shape/dtype 的类型注释和运行时检查:jaxtyping。
免责声明
这不是 Google 的官方产品。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。
源分布
sympy2jax-0.0.4.tar.gz
(8.8 kB
查看哈希)
内置分布
sympy2jax-0.0.4-py3-none-any.whl
(9.3 kB
查看哈希)