JAX 的神经网络
项目描述
JAXnet
JAXnet 是一个基于JAX的深度学习库。JAXnet 的功能 API 提供了优于 TensorFlow2、Keras 和 PyTorch 的独特优势,同时保持了用户友好性、模块化和可扩展性:
- 通过不可变的权重提高鲁棒性,没有全局计算图。
numpy用于网络、训练循环、预处理和后处理的GPU 编译代码。- 一行中的任何模块或整个网络的正则化和重新参数化。
- 无全局随机状态,灵活的随机密钥控制。
如果您已经了解 stax,请阅读本文。
模块化
net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)
从预定义的模块创建神经网络模型。
可扩展性
@parametrized使用函数定义您自己的模块。您可以重用其他模块:
from jax import numpy as jnp
@parametrized
def loss(inputs, targets):
return -jnp.mean(net(inputs) * targets)
所有模块都以这种方式组成。
jax.numpy是 mirroring numpy,意思是如果你知道如何使用numpy,你就知道 JAXnet 的大部分内容。将此与 TensorFlow2/Keras 进行比较:
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Lambda
net = Sequential([Dense(1024, 'relu'), Dense(1024, 'relu'), Dense(4), Lambda(tf.nn.log_softmax)])
def loss(inputs, targets):
return -tf.reduce_mean(net(inputs) * targets)
注意LambdaJAXnet 中不需要层。
relu并且logsoftmax是普通的 Python 函数。
不可变的权重
与 TensorFlow2/Keras 不同,JAXnet 没有全局计算图。模块喜欢net并且loss不包含可变权重。相反,权重包含在单独的、不可变的对象中。它们使用init_parameters、提供的示例输入和随机密钥进行初始化:
from jax.random import PRNGKey
def next_batch(): return jnp.zeros((3, 784)), jnp.zeros((3, 4))
params = loss.init_parameters(*next_batch(), key=PRNGKey(0))
print(params.sequential.dense2.bias) # [-0.01101029, -0.00749435, -0.00952365, 0.00493979]
优化器不是内联改变权重,而是返回更新版本的权重。它们作为新优化器状态的一部分返回,并且可以通过以下方式检索get_parameters:
opt = optimizers.Adam()
state = opt.init(params)
for _ in range(10):
state = opt.update(loss.apply, state, *next_batch()) # accelerate with jit=True
trained_params = opt.get_parameters(state)
apply评估网络:
test_loss = loss.apply(trained_params, *test_batch) # accelerate with jit=True
GPU 支持和编译
JAX 允许加速功能numpy/代码。scipy通过将numpy导入替换为jax.numpy. 通过用 . 装饰函数来编译它jit。这将使您的函数从缓慢的 Python 解释中解放出来,尽可能并行化操作并优化您的计算图。它提供了 TensorFlow2 或 PyTorch 级别的速度和可扩展性。
由于权重不可变,整个训练循环可以在 GPU 上编译/运行(演示)。
jit将使您的训练与内联改变权重一样快,并且权重不会离开 GPU。您可以编写函数式代码而不必担心性能。
您可以以相同的方式轻松加速numpy/scipy预处理/后处理代码(演示)。
正则化和重新参数化
在 JAXnet 中,可以在一行中完成模型的正则化(演示):
loss = L2Regularized(loss, scale=.1)
loss现在只是另一个可以像上面一样使用的模块。重新参数化的层也是单行的(参见API)。JAXnet 允许在不更改其代码的情况下对任何模块或子网进行正则化或重新参数化。这是可能的,因为模块不实例化任何变量。取而代之的是,每个模块都提供了一个apply带有参数作为参数的函数 ( )。可以包装此函数以构建层,例如L2Regularized.
相比之下,TensorFlow2/Keras/PyTorch 的模型 API 中包含可变变量。因此,他们要求:
随机键控制
JAXnet 没有全局随机状态。随机键是显式提供的,使代码具有确定性,并且默认情况下独立于先前执行的代码。这可以帮助调试并且更灵活(演示)。在此处阅读有关 JAX 中随机数的更多信息。
逐步调试
JAXnet 允许使用具体值进行分步调试,就像任何普通的 Python 函数一样(jit不使用编译时)。
API 和演示
在此处查找有关 API 的更多详细信息。
在浏览器中查看 JAXnet: Mnist Classifier、 Mnist VAE、 带有 RNNs 的 OCR、 ResNet、 WaveNet、 PixelCNN++和 Policy Gradient RL。
安装
这是预览。期待突破性的变化!支持 Python 3.6 或更高版本。安装
pip3 install jaxnet
要使用 GPU,首先安装正确版本的 jaxlib。
问题
请随时在 GitHub 上创建问题。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。