JAX extensions

bartz.jaxext

Additions to jax.

bartz.jaxext.vmap_nodoc(fun, *args, **kw)

Acts like jax.vmap but preserves the docstring of the function unchanged.

This is useful if the docstring already takes into account that the arguments have additional axes due to vmap.
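Here is a minimal sketch of how such a wrapper could work. It assumes the only thing to undo is that `jax.vmap` replaces the wrapped function's docstring with a generic "Vectorized version of ..." text. The name `vmap_keep_doc` is hypothetical, and this is not bartz's code.

```python
import jax
import jax.numpy as jnp


def vmap_keep_doc(fun, *args, **kw):
    """Hypothetical stand-in for vmap_nodoc: jax.vmap, but keep fun's docstring."""
    vfun = jax.vmap(fun, *args, **kw)  # jax.vmap installs its own docstring
    vfun.__doc__ = fun.__doc__  # restore the original one verbatim
    return vfun


def norm(x):
    """Compute the norms of the rows of x, which has shape (n, d)."""
    # Written per row, but the docstring already describes the batched call.
    return jnp.sqrt(jnp.sum(x**2))


batched_norm = vmap_keep_doc(norm)
```

The docstring of `norm` describes the batched signature, so keeping it verbatim is what makes `help(batched_norm)` accurate.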

bartz.jaxext.minimal_unsigned_dtype(value)

Return the smallest unsigned integer dtype that can represent value.
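The selection rule can be sketched by trying the unsigned types from narrowest to widest. This hypothetical helper uses NumPy dtypes, which JAX accepts. For example, 255 fits in uint8 but 256 needs uint16. Keep in mind that 64-bit arrays in JAX require x64 mode to be enabled.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(value):
    """Smallest unsigned integer dtype whose range includes value (hypothetical sketch)."""
    if value < 0:
        raise ValueError('unsigned dtypes cannot represent negative values')
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if value <= np.iinfo(dtype).max:  # largest value this dtype can hold
            return np.dtype(dtype)
    raise OverflowError(f'{value} does not fit in uint64')
```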

bartz.jaxext.unique(x, size, fill_value)

Restricted version of jax.numpy.unique that uses less memory.

Parameters:
  • x (Shaped[Array, '_']) – The input array.

  • size (int) – The length of the output.

  • fill_value (Shaped[Array, '']) – The value to fill the output with if size is greater than the number of unique values in x.

Returns:

  • out (Shaped[Array, '{size}']) – The unique values in x, sorted, and right-padded with fill_value.

  • actual_length (int) – The number of used values in out.
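One way to produce a fixed-size, padded result like the one described above is to sort the input and then scatter. The sketch below assumes this approach; it is not bartz's implementation and makes no claim about matching its memory use. Because `size` is a static Python int, the output shape is known at trace time, which is what allows such a function to run under `jax.jit`. Here `actual_length` is returned as a JAX integer scalar.

```python
import jax.numpy as jnp


def unique_sketch(x, size, fill_value):
    """Sorted unique values of x, right-padded with fill_value to length size."""
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    xs = jnp.sort(x)
    # An element starts a new run of equal values if it differs from its predecessor.
    is_new = jnp.concatenate([jnp.array([True]), xs[1:] != xs[:-1]])
    # Output slot for each run start; duplicates get the out-of-bounds slot `size`.
    slot = jnp.where(is_new, jnp.cumsum(is_new) - 1, size)
    # mode='drop' discards writes to slots >= size (duplicates and overflow).
    out = jnp.full(size, fill_value, x.dtype).at[slot].set(xs, mode='drop')
    actual_length = jnp.minimum(jnp.sum(is_new), size)
    return out, actual_length
```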

class bartz.jaxext.split(key, num=2)

Split a key into num keys.

Parameters:
  • key (Key[Array, '*batch']) – The key to split.

  • num (int, default: 2) – The number of keys to split into.

Notes

Unlike jax.random.split, this class supports a vector of keys as input. In this case, it behaves as if everything had been vmapped over: keys.pop returns an additional leading dimension equal to the number of input keys, and each slice along this axis depends only on the corresponding input key.

pop(shape=())

Pop one or more keys from the list.

Parameters:

shape (int | tuple[int, ...], default: ()) – The shape of the keys to pop. If empty (default), a single key is popped and returned. If not empty, the popped key is split and reshaped to the target shape.

Returns:

Key[Array, '*batch {shape}'] – The popped keys as a JAX array with the requested shape.

Raises:

IndexError – If the list is empty.
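The single-key behavior of split and pop can be sketched as a small class that splits once up front and then hands out keys one at a time. This sketch assumes new-style typed keys from `jax.random.key`. The class name `KeyListSketch` is hypothetical, and the batched (vmapped) behavior from the Notes is not reproduced.

```python
import math

import jax


class KeyListSketch:
    """Hypothetical single-key analogue of split: a list of keys to pop from."""

    def __init__(self, key, num=2):
        # Split once up front and store the individual keys.
        keys = jax.random.split(key, num)
        self._keys = [keys[i] for i in range(num)]

    def pop(self, shape=()):
        if not self._keys:
            raise IndexError('No more keys to pop')
        key = self._keys.pop(0)
        if shape == ():
            return key  # a single key, returned as is
        if isinstance(shape, int):
            shape = (shape,)
        # Split the popped key into as many keys as the target shape needs.
        subkeys = jax.random.split(key, math.prod(shape))
        return subkeys.reshape(shape + key.shape)
```

Usage: `keys = KeyListSketch(jax.random.key(0), 3)`, then `keys.pop()` gives one key and `keys.pop((2, 3))` gives a 2×3 array of keys. A fourth pop raises `IndexError`.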

bartz.jaxext.truncated_normal_onesided(key, shape, upper, bound, *, clip=True)

Sample from a one-sided truncated standard normal distribution.

Parameters:
  • key (Key[Array, '']) – JAX random key.

  • shape (Sequence[int]) – Shape of output array, broadcasted with other inputs.

  • upper (Bool[Array, '*']) – True for (-∞, bound], False for [bound, ∞).

  • bound (Float32[Array, '*']) – The truncation boundary.

  • clip (bool, default: True) – Whether to clip the truncated uniform samples to (0, 1) before transforming them to truncated normal. Intended for debugging purposes.

Returns:

Float32[Array, '*'] – Array of samples from the truncated normal distribution.
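A minimal inverse-CDF sketch of one-sided truncated normal sampling follows. Whether bartz uses exactly this technique is an assumption, and the function name is hypothetical. Because N(0, 1) is symmetric, sampling from [bound, ∞) is the same as negating a sample from (-∞, -bound]. So the sketch always draws u uniformly on [0, Φ(b)) and returns Φ⁻¹(u). The optional clipping keeps u strictly inside (0, 1) so that Φ⁻¹ stays finite. Naive inverse-CDF sampling loses accuracy far in the tails, for example when upper is False and bound is large.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_onesided_sketch(key, shape, upper, bound, clip=True):
    """Inverse-CDF sampler for N(0, 1) truncated to (-inf, bound] or [bound, inf)."""
    # By symmetry, [bound, inf) is the mirror image of (-inf, -bound].
    b = jnp.where(upper, bound, -bound)
    # Uniform on [0, Phi(b)), broadcast against b.
    u = jax.random.uniform(key, shape) * ndtr(b)
    if clip:
        # Keep u strictly inside (0, 1) so that ndtri stays finite.
        one = jnp.ones((), u.dtype)
        u = jnp.maximum(u, jnp.finfo(u.dtype).tiny)
        u = jnp.minimum(u, jnp.nextafter(one, 0 * one))
    z = ndtri(u)  # sample from (-inf, b]
    return jnp.where(upper, z, -z)
```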

bartz.jaxext.get_default_device()

Get the current default JAX device.

Return type:

Device

bartz.jaxext.get_device_count()

Get the number of available devices on the default platform.

Return type:

int
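Both queries can be sketched with public JAX APIs. One assumption here is that "default device" means the device on which a freshly created array lands. That respects the `jax.default_device` context manager. The helper names are hypothetical, and bartz may compute these differently.

```python
import jax
import jax.numpy as jnp


def default_device_sketch():
    """Device on which a freshly created, uncommitted array is placed."""
    (device,) = jnp.zeros(()).devices()
    return device


def device_count_sketch():
    """Number of devices on the platform of the default device."""
    return len(jax.devices(default_device_sketch().platform))
```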

bartz.jaxext.is_key(x)

Determine if x is a jax random key.

Return type:

bool
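A possible check, assuming "random key" means a new-style typed key array such as those created by `jax.random.key`, is to test the array's dtype against `jax.dtypes.prng_key`. Under this assumption, legacy `uint32` keys from `jax.random.PRNGKey` are not recognized. The helper name is hypothetical.

```python
import jax
import jax.numpy as jnp


def is_key_sketch(x):
    """True for new-style typed PRNG key arrays (e.g. from jax.random.key)."""
    return isinstance(x, jax.Array) and jax.dtypes.issubdtype(
        x.dtype, jax.dtypes.prng_key
    )
```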

bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, *, return_nbatches=False, reduce_ufunc=None, warn_on_overflow=True, result_shape_dtype=NotDefined)

Batch a function such that each batch is smaller than a threshold.

Parameters:
  • func (Callable) – A jittable function that takes positional arguments only and whose inputs and outputs are pytrees of arrays.

  • max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).

  • in_axes (PyTree[int | None], default: 0) – A tree matching (a prefix of) the structure of the function input, indicating along which axes each array should be batched. A None axis indicates to not batch an argument.

  • out_axes (PyTree[int], default: 0) – The same for the outputs, except that None (no batching) is not allowed.

  • return_nbatches (bool, default: False) – If True, the number of batches is returned as a second output.

  • reduce_ufunc (ufunc | None, default: None) – Function used to reduce the output along the batched axis (e.g., jax.numpy.add).

  • warn_on_overflow (bool, default: True) – If True, a warning is raised if the memory limit could not be respected.

  • result_shape_dtype (PyTree[ShapeDtypeStruct], default: NotDefined) – A pytree of dummy arrays matching the expected output. If not provided, the function is traced an additional time to determine the output structure.

Returns:

Callable – A function with the same signature as func, except that, if return_nbatches is True, it also returns the number of batches as a second output.
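The core batching idea can be illustrated with a much-simplified sketch. It assumes a single array input and a single array output, both batched along axis 0, and it omits pytrees, in_axes/out_axes, reduce_ufunc, return_nbatches, overflow warnings, and nesting. The name `autobatch_sketch` is hypothetical. The batch size is derived from the input + output bytes per row. `jax.lax.map` then calls func on one batch at a time, so each call sees at most max_io_nbytes of input plus output, provided that at least one row fits. Padding rows are zeros, so func must tolerate them.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes):
    """Simplified autobatch: one array in, one array out, both batched on axis 0.

    Assumes func acts independently on each row along axis 0, so it can be
    applied to chunks of rows and the results concatenated.
    """

    def batched(x):
        out = jax.eval_shape(func, x)  # output shape/dtype without running func
        n = x.shape[0]
        in_nbytes = math.prod(x.shape) * x.dtype.itemsize
        out_nbytes = math.prod(out.shape) * out.dtype.itemsize
        row_nbytes = (in_nbytes + out_nbytes) / n  # input + output bytes per row
        batch_size = max(1, int(max_io_nbytes // row_nbytes))
        nbatches = -(-n // batch_size)  # ceiling division
        # Pad to a whole number of batches, run the batches sequentially, unpad.
        pad = nbatches * batch_size - n
        xp = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        xb = xp.reshape(nbatches, batch_size, *x.shape[1:])
        yb = jax.lax.map(func, xb)
        return yb.reshape(nbatches * batch_size, *yb.shape[2:])[:n]

    return batched
```

For example, a float32 array of shape (10, 3) passed through an elementwise function costs 24 bytes per row. With max_io_nbytes=80, this gives batches of 3 rows and 4 sequential calls.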

Notes

Unless return_nbatches or reduce_ufunc is set, autobatch is idempotent for fixed arguments: applying it again, with the same arguments, to an already batched function has no further effect. Furthermore, autobatch can be nested with the same max_io_nbytes limit to batch over several axes. In this case, it won't needlessly loop over the additional axes if the outer autobatch calls already make the batches small enough.

To account for memory used by intermediate values, assume that their size also scales linearly with the length of the batched axis. Say the batched inputs and outputs have total size batched_size * core_io_size, and the intermediate values have size batched_size * core_int_size. Then divide max_io_nbytes by (1 + core_int_size / core_io_size).
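The rule follows because each batch element costs core_io_size + core_int_size in total, of which inputs and outputs are the fraction core_io_size / (core_io_size + core_int_size). Scaling the overall budget by that fraction is the same as dividing it by 1 + core_int_size / core_io_size. A small helper (hypothetical name) with a worked number:

```python
def adjusted_max_io_nbytes(total_nbytes, core_io_size, core_int_size):
    """Value to pass as max_io_nbytes so that I/O plus intermediates fit total_nbytes."""
    # Per batch element, total memory is core_io_size + core_int_size, so the
    # I/O share of the budget is total_nbytes * core_io_size / (core_io_size + core_int_size).
    return int(total_nbytes / (1 + core_int_size / core_io_size))


# 1 GiB budget, 16 bytes of I/O and 48 bytes of intermediates per element:
# divide by 1 + 48 / 16 = 4, leaving 256 MiB for inputs and outputs.
limit = adjusted_max_io_nbytes(2**30, 16, 48)
```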

bartz.jaxext.scipy.special

Mockup of the scipy.special module.

bartz.jaxext.scipy.special.gammainccinv(a, y)

Inverse of the survival function of the Gamma(a, 1) distribution.
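A sketch of one way to compute this in JAX uses bisection on `jax.scipy.special.gammaincc`, the regularized upper incomplete gamma function. That function is the survival function of Gamma(a, 1) and decreases from 1 at x = 0 towards 0, so the root is unique. The fixed bracket [0, hi] and the iteration count are assumptions, and the function name is hypothetical. bartz's implementation may use a different method.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc


def gammainccinv_bisect(a, y, hi=100.0, iters=60):
    """Solve gammaincc(a, x) = y for x in [0, hi] by bisection (hypothetical sketch).

    Assumes the root lies below hi, i.e. gammaincc(a, hi) <= y.
    """
    shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(y))
    lo = jnp.zeros(shape)
    hi = jnp.full(shape, hi)

    def step(_, bounds):
        lo, hi = bounds
        mid = (lo + hi) / 2
        # Survival still above the target: the root is to the right of mid.
        go_right = gammaincc(a, mid) > y
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    lo, hi = jax.lax.fori_loop(0, iters, step, (lo, hi))
    return (lo + hi) / 2
```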

bartz.jaxext.scipy.special.ndtri(p)

Compute the inverse of the CDF of the standard normal distribution.

This is a patched version of jax.scipy.special.ndtri.
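For reference, the relationship being inverted is shown below with the stock `jax.scipy.special` functions, not bartz's patched version. `ndtr` is the standard normal CDF and `ndtri` is its inverse. For example, `ndtri(0.975)` is the familiar two-sided 95% critical value, about 1.96.

```python
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri

p = jnp.array([0.025, 0.5, 0.975])
z = ndtri(p)  # quantiles of N(0, 1): about -1.96, 0, 1.96
p_back = ndtr(z)  # the CDF maps the quantiles back to p
```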

bartz.jaxext.scipy.stats

Mockup of the scipy.stats module.

class bartz.jaxext.scipy.stats.invgamma

Class that represents the distribution InvGamma(a, 1).

static ppf(q, a)

Percent point function (inverse of the cumulative distribution function).
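The quantile of InvGamma(a, 1) reduces to the Gamma survival-function inverse. If X ~ Gamma(a, 1), then 1/X ~ InvGamma(a, 1), and P(1/X ≤ t) = P(X ≥ 1/t) = gammaincc(a, 1/t). Setting this equal to q gives t = 1 / gammainccinv(a, q). The sketch below checks this identity with SciPy's reference functions rather than the JAX mockups documented here. The helper name is hypothetical.

```python
import numpy as np
from scipy.special import gammainccinv
from scipy.stats import invgamma


def invgamma_ppf(q, a):
    """Quantile of InvGamma(a, 1) via the survival-function inverse of Gamma(a, 1)."""
    # P(1/X <= t) = P(X >= 1/t) = gammaincc(a, 1/t); solve gammaincc(a, 1/t) = q.
    return 1.0 / gammainccinv(a, q)


q = np.array([0.1, 0.5, 0.9])
a = 3.0
t = invgamma_ppf(q, a)
```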