Chex:在 JAX 中测试变得有趣!
项目描述
切克斯
Chex 是一个实用程序库,用于帮助编写可靠的 JAX 代码。
这包括帮助:
- 检测您的代码(例如断言)
- 调试(例如
pmaps
在vmaps
上下文管理器中转换)。 - 跨多个测试 JAX 代码
variants
(例如,jitted 与 non-jitted)。
安装
您可以通过以下方式从 PyPI 安装最新发布的 Chex 版本:
pip install chex
或者您可以从 GitHub 安装最新的开发版本:
pip install git+https://github.com/deepmind/chex.git
模块概述
数据类(dataclass.py )
数据类是 Python 3.7 引入的一种流行结构,允许使用最少的样板代码轻松指定类型化数据结构。但是,它们与 开箱即用的 JAX 和dm-tree不兼容。
在 Chex 中,我们提供了一个重用 python 数据类的 JAX 友好的数据类实现。
dataclass
将数据类注册为内部PyTree
节点的 Chex 实现,以确保与 JAX 数据结构的兼容性。
此外,我们提供了一个类包装器,它将数据类公开为
collections.Mapping
后代,允许在dm-tree
方法中像通常的 Python 字典一样处理它们(例如(未)展平)。@mappable_dataclass
有关详细信息,请参阅文档字符串。
例子:
@chex.dataclass
class Parameters:
x: chex.ArrayDevice
y: chex.ArrayDevice
parameters = Parameters(
x=jnp.ones((2, 2)),
y=jnp.ones((1, 2)),
)
# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)
# and as mappings by dm-tree
tree.flatten(parameters)
注意:与标准 Python 3.7 数据类不同,Chex 数据类不能使用位置参数构造。它们支持以与 Python dict 构造函数相同的格式提供的构造参数。from_tuple
如有必要,可以使用和
to_tuple
方法将数据类转换为元组。
parameters = Parameters(
jnp.ones((2, 2)),
jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.
断言(asserts.py)
JAX 的 PyType 注释的一个限制是它们不支持DeviceArray
等级、形状或 dtype 的规范。Chex 包含许多函数,这些函数允许对这些属性进行灵活和简洁的规范。
例如,假设您要确保所有张量t1
, t2
,t3
具有相同的形状,并且张量t4
,分别t5
具有秩2
和 (3
或4
)。
chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])
更多示例:
from chex import assert_shape, assert_rank, ...
assert_shape(x, (2, 3)) # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)]) # x is scalar and y has shape (2, 3)
assert_rank(x, 0) # x is scalar
assert_rank([x, y], [0, 2]) # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2}) # x and y are scalar OR rank-2 arrays
assert_type(x, int) # x has type `int` (x can be an array)
assert_type([x, y], [int, float]) # x has type `int` and y has type `float`
assert_equal_shape([x, y, z]) # x, y, and z have equal shapes
assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x) # all tree_x leaves are finite
assert_devices_available(2, 'gpu') # 2 GPUs available
assert_tpu_available() # at least 1 TPU available
assert_numerical_grads(f, (x, y), j) # f^{(j)}(x, y) matches numerical grads
请参阅asserts.py
文档以查找所有支持的断言。
如果您找不到特定的断言,请考虑提出拉取请求或 在错误跟踪器上打开问题。
可选参数
所有 chex 断言都支持以下可选 kwargs 来操作发出的异常消息:
custom_message
:要包含在发出的异常消息中的字符串。include_default_message
: 是否在发出的异常消息中包含默认的 Hex 消息。exception_type
: 要使用的异常类型。AssertionError
默认。
例如,下面的代码:
dataset = load_dataset()
params = init_params()
for i in range(num_steps):
params = update_params(params, dataset.sample())
chex.assert_tree_all_finite(params,
custom_message=f'Failed at iteration {i}.',
exception_type=ValueError)
当被or s污染时,将引发ValueError
包含步数的 a。params
NaNs
None
静态和值(又名运行时)断言
Chex 将所有断言分为 2 类:静态断言和值 断言。
-
静态断言使用除了张量的具体值之外的任何东西。例子:
assert_shape
,assert_trees_all_equal_dtypes
,assert_max_traces
. -
值断言需要访问张量值,这些值在 JAX 跟踪期间不可用(请参阅 JAX 原语如何工作),因此此类断言需要在 jitted代码中进行特殊处理
要在 jited 函数中启用值断言,可以使用
chex.chexify()
wrapper 对其进行修饰。例子:
@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
chex.assert_tree_all_finite(x)
return jnp.log(jnp.abs(x) + 1)
logp1_abs_safe(jnp.ones(2)) # OK
logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS (in async mode)
# The error will be raised either at the next line OR at the next
# `logp1_abs_safe` call. See the docs for more detain on async mode.
logp1_abs_safe.wait_checks() # Wait for the (async) computation to complete.
有关
. _
_chex.chexify()
JAX 跟踪断言
每次传递参数的结构发生变化时,JAX 都会重新跟踪 JIT 函数。通常这种行为是无意的,并导致难以调试的显着性能下降。@chex.assert_max_tracesn
装饰器断言该函数在程序执行期间
不会被重新跟踪多次。
可以通过调用清除全局跟踪计数器
chex.clear_trace_counter()
。此函数用于隔离依赖于@chex.assert_max_traces
.
例子:
@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
return x + y
z = fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))
t = fn_sum_jitted(jnp.zeros(6, 7), jnp.zeros(6, 7)) # AssertionError!
也可以用于jax.pmap()
:
def fn_sub(x, y):
return x - y
fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
有关跟踪的更多信息,请参阅 JAX 原语如何工作 部分。
测试变体(variants.py)
JAX 广泛依赖于代码转换和编译,这意味着很难确保代码得到正确测试。例如,仅使用 JAX 代码测试 python 函数将不会覆盖 jitted 时执行的实际代码路径,并且无论代码是针对 CPU、GPU 还是 TPU 进行 jitted,该路径也会有所不同。这是一个晦涩难懂且难以捕捉的 bug 的来源,其中 XLA 更改会导致不良行为,但仅在一个特定的代码转换中表现出来。
变体通过提供一个简单的装饰器,可以很容易地确保单元测试涵盖函数的不同“变体”,该装饰器可用于在所有(或子集)相关代码转换下重复任何测试。
例如,假设您想测试fn
有或没有 jit 的函数的输出。您可以使用chex.variants
函数的 jitted 和非 jitted 版本运行测试,只需用 装饰测试方法
@chex.variants
,然后在测试主体中使用self.variant(fn)
代替。fn
def fn(x, y):
return x + y
...
class ExampleTest(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
def test(self):
var_fn = self.variant(fn)
self.assertEqual(fn(1, 2), 3)
self.assertEqual(var_fn(1, 2), fn(1, 2))
如果在测试方法中定义函数,也可以self.variant
在函数定义中作为装饰器使用。例如:
class ExampleTest(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
def test(self):
@self.variant
def var_fn(x, y):
return x + y
self.assertEqual(var_fn(1, 2), 3)
参数化测试示例:
from absl.testing import parameterized
# Could also be:
# `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
# `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):
@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
('case_positive', 1, 2, 3),
('case_negative', -1, -2, -3),
)
def test(self, arg_1, arg_2, expected):
@self.variant
def var_fn(x, y):
return x + y
self.assertEqual(var_fn(arg_1, arg_2), expected)
Chex 当前支持以下变体:
with_jit
--jax.jit()
对函数应用转换。without_jit
-- 按原样使用函数,即恒等变换。with_device
-- 在应用函数之前,将所有参数(参数中指定的除外ignore_argnums
)放入设备内存中。without_device
-- 在应用函数之前将所有参数放在 RAM 中。with_pmap
--jax.pmap()
对函数应用转换(见下面的注释)。
有关支持的变体的更多详细信息,请参阅 variables.py 中的文档。更多示例可以在variables_test.py中找到。
变体注释
-
使用的测试类
@chex.variants
必须继承自chex.TestCase
(或在 中展开测试生成器的任何其他基类TestCase
,例如absl.testing.parameterized.TestCase
)。 -
[
jax.vmap
]所有变体都可以应用于 vmapped 函数;请参阅variables_test.py (test_vmapped_fn_named_params
和test_pmap_vmapped_fn
)中的示例。 -
[
@chex.all_variants
]您可以使用装饰器获得所有支持的变体@chex.all_variants
。 -
[
with_pmap
变体]jax.pmap(fn)
( docfn
)在多个设备上执行并行映射。由于大多数测试在单设备环境中运行(即可以访问单个 CPU 或 GPU),在这种情况下,它jax.pmap
的功能等同于jax.jit
,with_pmap
因此默认情况下会跳过变体(尽管它适用于单个设备)。下面我们描述了一种正确测试fn
它是否应该在多设备环境(TPU 或多个 CPU/GPU)中使用的方法。with_pmap
要在单个设备的情况下 禁用跳过 变体,请添加--chex_skip_pmap_variant_if_single_device=false
到您的测试命令。
假货 ( fake.py )
代码转换(例如jit
和)使 JAX 中的调试变得更加困难pmap
,这些转换引入了使代码难以检查和跟踪的优化。在调试期间禁用这些转换也很困难,因为它们可以在底层代码的多个位置调用。Chex 提供了jax.jit
用 no-op 转换和jax.pmap
(non-parallel)全局替换的工具jax.vmap
,以便在单设备上下文中更轻松地调试代码。
例如,您可以使用 Chex 进行伪造pmap
并将其替换为vmap
. 这可以通过使用上下文管理器包装您的代码来实现:
with chex.fake_pmap():
@jax.pmap
def fn(inputs):
...
# Function will be vmapped over inputs
fn(inputs)
也可以使用start
and调用相同的功能stop
:
fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()
此外,您可以使用多线程 CPU 伪造一个真正的多设备测试环境。有关更多详细信息,请参阅伪造多设备测试环境部分。
有关更多详细信息,请参阅 fake.py 中的文档和fake_test.py中的示例。
伪造多设备测试环境
在您无法轻松访问多个设备的情况下,您仍然可以使用单设备多线程测试并行计算。
特别是,可以强制 XLA 将单个 CPU 的线程用作单独的设备,即用多线程环境伪造一个真正的多设备环境。从 XLA 的角度来看,这两个选项在理论上是等效的,因为它们公开了相同的接口并使用了相同的抽象。
Chex 有一个标志chex_n_cpu_devices
,用于指定用作 XLA 设备的 CPU 线程数。
要为测试设置多线程 XLA 环境,请在测试模块中absl
定义
setUpModule
函数:
def setUpModule():
chex.set_n_cpu_devices()
现在您可以启动您的测试以python test.py --chex_n_cpu_devices=N
在多设备状态下运行它。请注意,模块中的所有测试都可以访问N
设备。
更多示例可以在variables_test.py、fake_test.py和fake_set_n_cpu_devices_test.py中找到。
使用命名的尺寸大小。
Chex 带有一个小实用程序,允许您将一组尺寸大小打包到一个对象中。基本思想是:
dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
chex.assert_shape(arr, dims['BTE'])
字符串查找是翻译的整数元组。例如,假设
batch_size == 3
,sequence_len = 5
和embedding_dim = 7
, 那么
dims['BTE'] == (3, 5, 7)
dims['B'] == (3,)
dims['TTBEE'] == (5, 5, 3, 7, 7)
...
您还可以动态分配尺寸大小,如下所示:
dims['XY'] = some_matrix.shape
dims.Z = 13
有关更多示例,请参阅chex.Dimensions 文档。
引用 Chex
此存储库是DeepMind JAX 生态系统的一部分,要引用 Chex,请使用DeepMind JAX 生态系统引用。