Skip to main content

Chex:在 JAX 中测试变得有趣!

项目描述

切克斯

CI 状态 文档 皮皮

Chex 是一个实用程序库,用于帮助编写可靠的 JAX 代码。

这包括帮助:

  • 检测您的代码(例如断言)
  • 调试(例如pmapsvmaps上下文管理器中转换)。
  • 跨多个测试 JAX 代码variants(例如,jitted 与 non-jitted)。

安装

您可以通过以下方式从 PyPI 安装最新发布的 Chex 版本:

pip install chex

或者您可以从 GitHub 安装最新的开发版本:

pip install git+https://github.com/deepmind/chex.git

模块概述

数据类(dataclass.py )

数据类是 Python 3.7 引入的一种流行结构,允许使用最少的样板代码轻松指定类型化数据结构。但是,它们与 开箱即用的 JAX 和dm-tree不兼容。

在 Chex 中,我们提供了一个重用 python 数据类的 JAX 友好的数据类实现

dataclass将数据类注册为内部PyTree 节点的 Chex 实现,以确保与 JAX 数据结构的兼容性。

此外,我们提供了一个类包装器,它将数据类公开为 collections.Mapping后代,允许在dm-tree方法中像通常的 Python 字典一样处理它们(例如(未)展平)。@mappable_dataclass 有关详细信息,请参阅文档字符串。

例子:

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)

注意:与标准 Python 3.7 数据类不同,Chex 数据类不能使用位置参数构造。它们支持以与 Python dict 构造函数相同的格式提供的构造参数。from_tuple如有必要,可以使用和 to_tuple方法将数据类转换为元组。

parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.

断言(asserts.py

JAX 的 PyType 注释的一个限制是它们不支持DeviceArray等级、形状或 dtype 的规范。Chex 包含许多函数,这些函数允许对这些属性进行灵活和简洁的规范。

例如,假设您要确保所有张量t1, t2,t3具有相同的形状,并且张量t4,分别t5具有秩2和 (34)。

chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])

更多示例:

from chex import assert_shape, assert_rank, ...

assert_shape(x, (2, 3))                # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)])      # x is scalar and y has shape (2, 3)

assert_rank(x, 0)                      # x is scalar
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar OR rank-2 arrays

assert_type(x, int)                    # x has type `int` (x can be an array)
assert_type([x, y], [int, float])      # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])          # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite

assert_devices_available(2, 'gpu')     # 2 GPUs available
assert_tpu_available()                 # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)   # f^{(j)}(x, y) matches numerical grads

请参阅asserts.py 文档以查找所有支持的断言。

如果您找不到特定的断言,请考虑提出拉取请求或 在错误跟踪器上打开问题。

可选参数

所有 chex 断言都支持以下可选 kwargs 来操作发出的异常消息:

  • custom_message:要包含在发出的异常消息中的字符串。
  • include_default_message: 是否在发出的异常消息中包含默认的 Hex 消息。
  • exception_type: 要使用的异常类型。AssertionError默认。

例如,下面的代码:

dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)

当被or s污染时,将引发ValueError包含步数的 a。paramsNaNsNone

静态和值(又名运行时)断言

Chex 将所有断言分为 2 类:静态断言和 断言。

  1. 静态断言使用除了张量的具体值之外的任何东西。例子:assert_shape,assert_trees_all_equal_dtypes, assert_max_traces.

  2. 断言需要访问张量值,这些值在 JAX 跟踪期间不可用(请参阅 JAX 原语如何工作),因此此类断言需要在 jitted代码中进行特殊处理

要在 jited 函数中启用值断言,可以使用 chex.chexify()wrapper 对其进行修饰。例子:

  @chex.chexify
  @jax.jit
  def logp1_abs_safe(x: chex.Array) -> chex.Array:
    chex.assert_tree_all_finite(x)
    return jnp.log(jnp.abs(x) + 1)

  logp1_abs_safe(jnp.ones(2))  # OK
  logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

  # The error will be raised either at the next line OR at the next
  # `logp1_abs_safe` call. See the docs for more detain on async mode.
  logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.

有关 . _ _chex.chexify()

JAX 跟踪断言

每次传递参数的结构发生变化时,JAX 都会重新跟踪 JIT 函数。通常这种行为是无意的,并导致难以调试的显着性能下降。@chex.assert_max_tracesn装饰器断言该函数在程序执行期间 不会被重新跟踪多次。

可以通过调用清除全局跟踪计数器 chex.clear_trace_counter()。此函数用于隔离依赖于@chex.assert_max_traces.

例子:

  @jax.jit
  @chex.assert_max_traces(n=1)
  def fn_sum_jitted(x, y):
    return x + y

  z = fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))
  t = fn_sum_jitted(jnp.zeros(6, 7), jnp.zeros(6, 7))  # AssertionError!

也可以用于jax.pmap()

  def fn_sub(x, y):
    return x - y

  fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

有关跟踪的更多信息,请参阅 JAX 原语如何工作 部分。

测试变体(variants.py

JAX 广泛依赖于代码转换和编译,这意味着很难确保代码得到正确测试。例如,仅使用 JAX 代码测试 python 函数将不会覆盖 jitted 时执行的实际代码路径,并且无论代码是针对 CPU、GPU 还是 TPU 进行 jitted,该路径也会有所不同。这是一个晦涩难懂且难以捕捉的 bug 的来源,其中 XLA 更改会导致不良行为,但仅在一个特定的代码转换中表现出来。

变体通过提供一个简单的装饰器,可以很容易地确保单元测试涵盖函数的不同“变体”,该装饰器可用于在所有(或子集)相关代码转换下重复任何测试。

例如,假设您想测试fn有或没有 jit 的函数的输出。您可以使用chex.variants函数的 jitted 和非 jitted 版本运行测试,只需用 装饰测试方法 @chex.variants,然后在测试主体中使用self.variant(fn)代替。fn

def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))

如果在测试方法中定义函数,也可以self.variant 在函数定义中作为装饰器使用。例如:

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
       return x + y

    self.assertEqual(var_fn(1, 2), 3)

参数化测试示例:

from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
       return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex 当前支持以下变体:

  • with_jit--jax.jit()对函数应用转换。
  • without_jit-- 按原样使用函数,即恒等变换。
  • with_device-- 在应用函数之前,将所有参数(参数中指定的除外ignore_argnums )放入设备内存中。
  • without_device-- 在应用函数之前将所有参数放在 RAM 中。
  • with_pmap--jax.pmap()对函数应用转换(见下面的注释)。

有关支持的变体的更多详细信息,请参阅 variables.py 中的文档。更多示例可以在variables_test.py中找到。

变体注释

  • 使用的测试类@chex.variants必须继承自 chex.TestCase(或在 中展开测试生成器的任何其他基类TestCase,例如absl.testing.parameterized.TestCase)。

  • [ jax.vmap]所有变体都可以应用于 vmapped 函数;请参阅variables_test.py (test_vmapped_fn_named_paramstest_pmap_vmapped_fn)中的示例。

  • [ @chex.all_variants]您可以使用装饰器获得所有支持的变体@chex.all_variants

  • [with_pmap变体] jax.pmap(fn) ( docfn )在多个设备上执行并行映射。由于大多数测试在单设备环境中运行(即可以访问单个 CPU 或 GPU),在这种情况下,它jax.pmap的功能等同于jax.jit with_pmap因此默认情况下会跳过变体(尽管它适用于单个设备)。下面我们描述了一种正确测试fn它是否应该在多设备环境(TPU 或多个 CPU/GPU)中使用的方法。with_pmap要在单个设备的情况下 禁用跳过 变体,请添加--chex_skip_pmap_variant_if_single_device=false到您的测试命令。

假货 ( fake.py )

代码转换(例如jit 和)使 JAX 中的调试变得更加困难pmap,这些转换引入了使代码难以检查和跟踪的优化。在调试期间禁用这些转换也很困难,因为它们可以在底层代码的多个位置调用。Chex 提供了jax.jit用 no-op 转换和jax.pmap(non-parallel)全局替换的工具jax.vmap,以便在单设备上下文中更轻松地调试代码。

例如,您可以使用 Chex 进行伪造pmap并将其替换为vmap. 这可以通过使用上下文管理器包装您的代码来实现:

with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)

也可以使用startand调用相同的功能stop

fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()

此外,您可以使用多线程 CPU 伪造一个真正的多设备测试环境。有关更多详细信息,请参阅伪造多设备测试环境部分。

有关更多详细信息,请参阅 fake.py 中的文档fake_test.py中的示例。

伪造多设备测试环境

在您无法轻松访问多个设备的情况下,您仍然可以使用单设备多线程测试并行计算。

特别是,可以强制 XLA 将单个 CPU 的线程用作单独的设备,即用多线程环境伪造一个真正的多设备环境。从 XLA 的角度来看,这两个选项在理论上是等效的,因为它们公开了相同的接口并使用了相同的抽象。

Chex 有一个标志chex_n_cpu_devices,用于指定用作 XLA 设备的 CPU 线程数。

要为测试设置多线程 XLA 环境,请在测试模块中absl定义 setUpModule函数:

def setUpModule():
  chex.set_n_cpu_devices()

现在您可以启动您的测试以python test.py --chex_n_cpu_devices=N在多设备状态下运行它。请注意,模块中的所有测试都可以访问N设备。

更多示例可以在variables_test.pyfake_test.pyfake_set_n_cpu_devices_test.py中找到。

使用命名的尺寸大小。

Chex 带有一个小实用程序,允许您将一组尺寸大小打包到一个对象中。基本思想是:

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
chex.assert_shape(arr, dims['BTE'])

字符串查找是翻译的整数元组。例如,假设 batch_size == 3,sequence_len = 5embedding_dim = 7, 那么

dims['BTE'] == (3, 5, 7)
dims['B'] == (3,)
dims['TTBEE'] == (5, 5, 3, 7, 7)
...

您还可以动态分配尺寸大小,如下所示:

dims['XY'] = some_matrix.shape
dims.Z = 13

有关更多示例,请参阅chex.Dimensions 文档。

引用 Chex

此存储库是DeepMind JAX 生态系统的一部分,要引用 Chex,请使用DeepMind JAX 生态系统引用

项目详情