JAX extensions
bartz.jaxext
Additions to jax.
- bartz.jaxext.vmap_nodoc(fun, *args, **kw)
Acts like jax.vmap but preserves the docstring of the function unchanged.
This is useful if the docstring already takes into account that the arguments have additional axes due to vmap.
- bartz.jaxext.minimal_unsigned_dtype(value)
Return the smallest unsigned integer dtype that can represent value.
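As a rough illustration of the selection rule, the sketch below picks the narrowest of uint8, uint16, uint32 and uint64 whose range covers value. It is a dependency-free stand-in, not the bartz implementation: the helper name minimal_unsigned_dtype_name is hypothetical, and it returns the dtype name as a string rather than a dtype object.

```python
def minimal_unsigned_dtype_name(value):
    """Return the name of the narrowest unsigned integer dtype that can hold `value`.

    Hypothetical stand-in: returns a string instead of a dtype object.
    """
    if value < 0:
        raise ValueError("unsigned dtypes cannot represent negative values")
    for bits in (8, 16, 32, 64):
        # An unsigned integer with `bits` bits covers 0 .. 2**bits - 1.
        if value < 2**bits:
            return f"uint{bits}"
    raise OverflowError("value does not fit in 64 bits")


print(minimal_unsigned_dtype_name(255))     # uint8
print(minimal_unsigned_dtype_name(256))     # uint16
print(minimal_unsigned_dtype_name(70_000))  # uint32
```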
- bartz.jaxext.unique(x, size, fill_value)
Restricted version of jax.numpy.unique that uses less memory.
- Parameters:
x (Shaped[Array, '_']) – The input array.
size (int) – The length of the output.
fill_value (Shaped[Array, '']) – The value to fill the output with if size is greater than the number of unique values in x.
- Returns:
out (Shaped[Array, '{size}']) – The unique values in x, sorted, and right-padded with fill_value.
actual_length (int) – The number of used values in out.
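The return contract above can be mimicked in plain Python. This is only a reference for the documented output (sorted unique values, right-padded with fill_value, plus the count of used entries). The helper name unique_reference is hypothetical, it works on lists rather than JAX arrays, and it makes no attempt at the memory savings of the real function. The behavior when size is smaller than the number of unique values is not documented here; the sketch truncates, which is an assumption.

```python
def unique_reference(x, size, fill_value):
    """Pure-Python reference for the documented (out, actual_length) contract."""
    values = sorted(set(x))
    # Assumption: if there are more unique values than `size`, keep the first `size`.
    actual_length = min(len(values), size)
    out = values[:actual_length] + [fill_value] * (size - actual_length)
    return out, actual_length


out, actual_length = unique_reference([3, 1, 3, 2, 1], size=5, fill_value=-1)
print(out)            # [1, 2, 3, -1, -1]
print(actual_length)  # 3
```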
- class bartz.jaxext.split(key, num=2)
Split a key into num keys.
- Parameters:
key (Key[Array, '*batch']) – The key to split.
num (int, default: 2) – The number of keys to split into.
Notes
Unlike jax.random.split, this class supports a vector of keys as input. In this case, it behaves as if everything had been vmapped over, so keys.pop has an additional initial output dimension equal to the number of input keys, and the deterministic dependency respects this axis.
- pop(shape=())
Pop one or more keys from the list.
- Parameters:
shape (int | tuple[int, ...], default: ()) – The shape of the keys to pop. If empty (default), a single key is popped and returned. If not empty, the popped key is split and reshaped to the target shape.
- Returns:
Key[Array, '*batch {shape}'] – The popped keys as a jax array with the requested shape.
- Raises:
IndexError – If the list is empty.
- bartz.jaxext.truncated_normal_onesided(key, shape, upper, bound, *, clip=True)
Sample from a one-sided truncated standard normal distribution.
- Parameters:
key (Key[Array, '']) – JAX random key.
shape (Sequence[int]) – Shape of output array, broadcasted with other inputs.
upper (Bool[Array, '*']) – True for (-∞, bound], False for [bound, ∞).
bound (Float32[Array, '*']) – The truncation boundary.
clip (bool, default: True) – Whether to clip the truncated uniform samples to (0, 1) before transforming them to truncated normal. Intended for debugging purposes.
- Returns:
Float32[Array, '*'] – Array of samples from the truncated normal distribution.
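The description of clip suggests inverse-CDF sampling: draw a uniform, restrict it to the probability interval allowed by the truncation, and map it through the normal quantile function. The scalar sketch below shows that idea with the standard library only. It is an assumption about the method, not the bartz implementation. The function name, the use of random.Random in place of a JAX key, and the clipping constant eps are all hypothetical.

```python
import random
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, standard deviation 1


def truncated_normal_onesided_scalar(rng, upper, bound, *, clip=True, eps=1e-12):
    """Draw one sample from a standard normal truncated on one side at `bound`.

    upper=True samples from (-inf, bound], upper=False from [bound, inf).
    """
    u = rng.random()  # uniform in [0, 1)
    p_bound = _STD_NORMAL.cdf(bound)
    if upper:
        p = u * p_bound  # probabilities in [0, cdf(bound))
    else:
        p = p_bound + u * (1.0 - p_bound)  # probabilities in [cdf(bound), 1)
    if clip:
        # Keep p strictly inside (0, 1) so the quantile function stays finite.
        p = min(max(p, eps), 1.0 - eps)
    return _STD_NORMAL.inv_cdf(p)


rng = random.Random(0)
below = [truncated_normal_onesided_scalar(rng, True, 0.5) for _ in range(1000)]
above = [truncated_normal_onesided_scalar(rng, False, 0.5) for _ in range(1000)]
print(max(below), min(above))
```

A scheme of this kind loses accuracy for bounds far in the tails, where cdf(bound) rounds to 0 or 1 in floating point.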
- bartz.jaxext.get_device_count()
Get the number of available devices on the default platform.
- Return type:
int
- bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, *, return_nbatches=False, reduce_ufunc=None, warn_on_overflow=True, result_shape_dtype=NotDefined)
Batch a function such that each batch is smaller than a threshold.
- Parameters:
func (Callable) – A jittable function with positional arguments only, with inputs and outputs pytrees of arrays.
max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).
in_axes (PyTree[int | None], default: 0) – A tree matching (a prefix of) the structure of the function input, indicating along which axes each array should be batched. A None axis indicates to not batch an argument.
out_axes (PyTree[int], default: 0) – The same for outputs (but non-batching is not allowed).
return_nbatches (bool, default: False) – If True, the number of batches is returned as a second output.
reduce_ufunc (ufunc | None, default: None) – Function used to reduce the output along the batched axis (e.g., jax.numpy.add).
warn_on_overflow (bool, default: True) – If True, a warning is raised if the memory limit could not be respected.
result_shape_dtype (PyTree[ShapeDtypeStruct], default: NotDefined) – A pytree of dummy arrays matching the expected output. If not provided, the function is traced an additional time to determine the output structure.
- Returns:
Callable – A function with the same signature as func, save for the return value if return_nbatches.
Notes
Unless return_nbatches or reduce_ufunc are set, autobatch at given arguments is idempotent. Furthermore, autobatch can be applied multiple times over multiple axes with the same max_io_nbytes limit to work on multiple axes; in this case it won’t unnecessarily loop over additional axes if one or more outer autobatch are already sufficient.
To handle memory used in intermediate values: assuming all intermediate values have size that scales linearly with the axis batched over, say the batched input/output total size is batched_size * core_io_size, and the intermediate values have size batched_size * core_int_size, then to take them into account divide max_io_nbytes by (1 + core_int_size / core_io_size).
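A worked instance of the adjustment for intermediate values described in the notes. The numbers and the helper name adjusted_max_io_nbytes are made up for illustration; only the division by (1 + core_int_size / core_io_size) comes from the text.

```python
def adjusted_max_io_nbytes(memory_budget_nbytes, core_io_size, core_int_size):
    """Shrink the I/O limit so that I/O plus intermediates fit in the budget.

    Assumes, as in the notes, that both I/O and intermediate sizes scale
    linearly with the batched axis.
    """
    return int(memory_budget_nbytes / (1 + core_int_size / core_io_size))


# Hypothetical sizes: per batch element, 1 kB of input + output and 3 kB of intermediates.
budget = 2**30  # 1 GiB total budget
limit = adjusted_max_io_nbytes(budget, core_io_size=1_000, core_int_size=3_000)
print(limit)  # 268435456, i.e. a quarter of the budget
```

The resulting value is what would be passed as max_io_nbytes, so that I/O plus intermediates together stay within the original budget.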
bartz.jaxext.scipy.special
Mockup of the scipy.special module.
- bartz.jaxext.scipy.special.gammainccinv(a, y)
Survival function inverse of the Gamma(a, 1) distribution.
- bartz.jaxext.scipy.special.ndtri(p)
Compute the inverse of the CDF of the Normal distribution function.
This is a patch of jax.scipy.special.ndtri.
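For orientation, the quantity ndtri computes is the standard normal quantile function. The standard-library check below illustrates the definition only; it does not call bartz or JAX.

```python
from statistics import NormalDist

std_normal = NormalDist()  # mean 0, standard deviation 1

p = 0.975
x = std_normal.inv_cdf(p)  # the value an inverse normal CDF returns for p
print(round(x, 2))  # 1.96

# Applying the CDF to the quantile recovers the probability.
roundtrip = std_normal.cdf(x)
```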
bartz.jaxext.scipy.stats
Mockup of the scipy.stats module.