JAX extensions
bartz.jaxext
Additions to jax.
- bartz.jaxext.vmap_nodoc(fun, *args, **kw)
Acts like jax.vmap but preserves the docstring of the function unchanged. This is useful if the docstring already takes into account that the arguments have additional axes due to vmap.
- Return type:
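A minimal sketch of the idea, assuming the original docstring can simply be copied back onto the function returned by jax.vmap (which, in current JAX versions, installs its own "Vectorized version of ..." docstring). The name vmap_nodoc_sketch is hypothetical; this is not the bartz source.

```python
import jax
import jax.numpy as jnp


def vmap_nodoc_sketch(fun, *args, **kw):
    """Hypothetical re-implementation: vmap fun but keep its docstring."""
    doc = fun.__doc__
    vfun = jax.vmap(fun, *args, **kw)
    # jax.vmap replaces the docstring with its own text; restore the original
    vfun.__doc__ = doc
    return vfun


def row_norms(x):
    """Euclidean norm of each row of x, shape (n, d) -> (n,)."""
    # written for a single row; the docstring already describes the batched shapes
    return jnp.sqrt(jnp.sum(x**2))


batched = vmap_nodoc_sketch(row_norms)
print(batched.__doc__)
```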
- bartz.jaxext.minimal_unsigned_dtype(value)
Return the smallest unsigned integer dtype that can represent value.
- Return type:
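A sketch of the selection rule, assuming value is a non-negative Python integer and that the candidates are NumPy's unsigned integer dtypes tried from narrowest to widest. The function name is hypothetical, and whether bartz returns a NumPy or a JAX dtype object is not stated here.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(value):
    """Hypothetical sketch: smallest unsigned integer dtype whose range contains value."""
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        # np.iinfo(...).max is the largest value the dtype can hold
        if value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise OverflowError(f'{value} does not fit in 64 bits')


print(minimal_unsigned_dtype_sketch(255))    # uint8
print(minimal_unsigned_dtype_sketch(256))    # uint16
print(minimal_unsigned_dtype_sketch(70000))  # uint32
```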
- bartz.jaxext.unique(x, size, fill_value)
Restricted version of jax.numpy.unique that uses less memory.
- Parameters:
  - x (Shaped[Array, '_']) – The input array.
  - size (int) – The length of the output.
  - fill_value (Shaped[Array, '']) – The value to fill the output with if size is greater than the number of unique values in x.
- Returns:
  - out (Shaped[Array, '{size}']) – The unique values in x, sorted, and right-padded with fill_value.
  - actual_length (int) – The number of used values in out.
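jax.numpy.unique itself accepts size and fill_value for fixed-size output. The sketch below shows one sort-based way to produce the two documented outputs; it assumes x is 1-D and non-empty, unique_sketch is a hypothetical name, and it is not necessarily how bartz achieves its lower memory use.

```python
import jax.numpy as jnp


def unique_sketch(x, size, fill_value):
    """Hypothetical fixed-size unique; assumes x is 1-D and non-empty."""
    x = jnp.sort(x)
    # True at the first occurrence of each distinct value in the sorted array
    is_new = jnp.concatenate([jnp.array([True]), x[1:] != x[:-1]])
    # output position of each distinct value
    idx = jnp.cumsum(is_new) - 1
    # send repeats, and distinct values beyond `size`, to an out-of-range slot
    idx = jnp.where(is_new & (idx < size), idx, size)
    out = jnp.full(size, fill_value, x.dtype)
    # mode='drop' discards the writes aimed at the out-of-range slot
    out = out.at[idx].set(x, mode='drop')
    actual_length = jnp.minimum(jnp.sum(is_new), size)
    return out, actual_length


out, n = unique_sketch(jnp.array([3, 1, 3, 2, 1]), size=5, fill_value=0)
# out == [1, 2, 3, 0, 0], n == 3
```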
- class bartz.jaxext.split(key, num=2)
Split a key into num keys.
- Parameters:
  - key (Key[Array, '']) – The key to split.
  - num (int, default: 2) – The number of keys to split into.
- pop(shape=())
Pop one or more keys from the list.
- Parameters:
  - shape (int | tuple[int, ...], default: ()) – The shape of the keys to pop. If empty (the default), a single key is popped and returned. If not empty, the popped key is split and reshaped to the target shape.
- Returns:
  Key[Array, '{shape}'] – The popped keys as a JAX array with the requested shape.
- Raises:
  IndexError – If the list is empty.
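A hypothetical class mirroring the documented behavior, assuming new-style typed keys from jax.random.key and a pre-split key array consumed from the front; KeyStackSketch is not the bartz class, and the order in which bartz consumes its keys is not stated here.

```python
import math

import jax


class KeyStackSketch:
    """Hypothetical stand-in for a split-and-pop key list (not bartz's code)."""

    def __init__(self, key, num=2):
        # pre-split the key into `num` independent keys
        self._keys = jax.random.split(key, num)

    def __len__(self):
        return self._keys.shape[0]

    def pop(self, shape=()):
        if len(self) == 0:
            raise IndexError('pop from an empty key list')
        key, self._keys = self._keys[0], self._keys[1:]
        if isinstance(shape, int):
            shape = (shape,)
        if shape == ():
            return key  # a single scalar key
        # split the popped key into as many keys as requested, then reshape
        return jax.random.split(key, math.prod(shape)).reshape(shape)


keys = KeyStackSketch(jax.random.key(0), num=3)
k = keys.pop()           # shape ()
grid = keys.pop((2, 3))  # shape (2, 3)
```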
- bartz.jaxext.truncated_normal_onesided(key, shape, upper, bound, *, clip=True)
Sample from a one-sided truncated standard normal distribution.
- Parameters:
  - key (Key[Array, '']) – JAX random key.
  - shape (Sequence[int]) – Shape of the output array, broadcast with the other inputs.
  - upper (Bool[Array, '*']) – True for (-∞, bound], False for [bound, ∞).
  - bound (Float32[Array, '*']) – The truncation boundary.
  - clip (bool, default: True) – Whether to clip the truncated uniform samples to (0, 1) before transforming them to truncated normal samples. Intended for debugging purposes.
- Returns:
  Float32[Array, '*'] – Array of samples from the truncated normal distribution.
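The mention of "truncated uniform samples" suggests inverse-CDF sampling. A sketch of that technique, assuming jax.scipy.special.ndtr and ndtri as the CDF and its inverse; the function name is hypothetical, and this plain version loses accuracy far in the tail, where ndtr(b) underflows, so it is not necessarily what bartz does.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_onesided_sketch(key, shape, upper, bound, *, clip=True):
    """Hypothetical inverse-CDF sampler for N(0, 1) truncated to one side of bound."""
    # By symmetry, X restricted to [bound, inf) has the law of -Y with Y restricted
    # to (-inf, -bound], so both cases reduce to an upper truncation at b.
    b = jnp.where(upper, bound, -bound)
    # uniform on [0, Phi(b)): the CDF values covered by the truncated region
    u = jax.random.uniform(key, shape) * ndtr(b)
    if clip:
        # keep u strictly inside (0, 1) so that ndtri stays finite
        finfo = jnp.finfo(u.dtype)
        u = jnp.clip(u, finfo.tiny, 1 - finfo.eps)
    z = ndtri(u)
    return jnp.where(upper, z, -z)
```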
- bartz.jaxext.get_device_count()
Get the number of available devices on the default platform.
- Return type:
int
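One plausible way to obtain this number, assuming "default platform" means JAX's default backend; the function name is hypothetical, and jax.device_count() reports the same count.

```python
import jax


def get_device_count_sketch() -> int:
    """Hypothetical: count the devices of JAX's default backend."""
    # jax.devices() with no argument lists the devices of the default backend
    return len(jax.devices())


print(get_device_count_sketch(), jax.default_backend())
```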
- bartz.jaxext.equal_shards(x, axis_name, **shard_map_kwargs)
Check that all shards of x are equal across the mesh axis axis_name.
- Parameters:
  - x (PyTree[Array, 'S']) – A pytree of arrays to check. Each array is checked separately.
  - axis_name (str) – The mesh axis name across which equality is checked. Equality is not checked across other axes.
  - **shard_map_kwargs (Any) – Additional arguments passed to jax.shard_map to set up the function that checks equality. You may need to specify in_specs, passing the (pytree of) jax.sharding.PartitionSpec that describes how x is sharded if the axes are not explicit, and mesh if no default mesh has been set with jax.set_mesh.
- Returns:
  PyTree[Bool[Array, ''], 'S'] – A pytree of booleans indicating whether each leaf is equal across devices along the mesh axis.
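A sketch of a per-shard body such a check could run under jax.shard_map, assuming jax.lax.all_gather is used to compare every shard with the first one; shards_equal_body is hypothetical and bartz may use a different collective. Because jax.vmap with an axis_name supports the same collectives, the body is exercised here on a single device, with each row of the stacked arrays standing in for one shard.

```python
import jax
import jax.numpy as jnp


def shards_equal_body(tree, axis_name):
    """Hypothetical per-shard check: is each leaf identical on every member of axis_name?"""

    def leaf_equal(leaf):
        # stack the copies held along axis_name on a new leading axis...
        gathered = jax.lax.all_gather(leaf, axis_name)
        # ...and compare all of them with the first one
        return jnp.all(gathered == gathered[0])

    return jax.tree_util.tree_map(leaf_equal, tree)


# each row plays the role of one shard along the named axis 'i'
stacked = {'a': jnp.array([[1, 2], [1, 2], [1, 2]]), 'b': jnp.array([0.0, 1.0, 0.0])}
result = jax.vmap(lambda t: shards_equal_body(t, 'i'), axis_name='i')(stacked)
# result['a'] is all True, result['b'] is all False
```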
- bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, *, return_nbatches=False, reduce_ufunc=None, warn_on_overflow=True, result_shape_dtype=NotDefined)
Batch a function such that each batch is smaller than a threshold.
- Parameters:
  - func (Callable) – A jittable function with positional arguments only, whose inputs and outputs are pytrees of arrays.
  - max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).
  - in_axes (PyTree[int | None], default: 0) – A tree matching (a prefix of) the structure of the function input, indicating along which axis each array should be batched. A None axis means the argument is not batched.
  - out_axes (PyTree[int], default: 0) – The same for outputs (but non-batching is not allowed).
  - return_nbatches (bool, default: False) – If True, the number of batches is returned as a second output.
  - reduce_ufunc (ufunc | None, default: None) – Function used to reduce the output along the batched axis (e.g., jax.numpy.add).
  - warn_on_overflow (bool, default: True) – If True, a warning is raised if the memory limit could not be respected.
  - result_shape_dtype (PyTree[ShapeDtypeStruct], default: NotDefined) – A pytree of dummy arrays matching the expected output. If not provided, the function is traced an additional time to determine the output structure.
- Returns:
  Callable – A function with the same signature as func, except for the return value if return_nbatches is set.
Notes
Unless return_nbatches or reduce_ufunc is set, autobatch is idempotent for fixed arguments: applying it twice with the same arguments is equivalent to applying it once. Furthermore, autobatch can be applied multiple times, over different axes and with the same max_io_nbytes limit, to batch over multiple axes; in this case it does not loop unnecessarily over an additional axis if one or more outer autobatch applications already keep the batches within the limit.
To account for memory used by intermediate values, assume all intermediate values have sizes that scale linearly with the batched axis: say the batched input/output total size is batched_size * core_io_size and the intermediate values have size batched_size * core_int_size. Then divide max_io_nbytes by (1 + core_int_size / core_io_size), so that the total batched_size * (core_io_size + core_int_size) stays within the original limit.
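A deliberately simplified sketch of the batching idea, assuming a single array argument and a single array output, both batched along axis 0, and equal-size batches chosen for brevity. autobatch_sketch and adjusted_limit are hypothetical helpers; per the documentation above, the real function also handles pytrees, in_axes/out_axes, reduce_ufunc, and overflow warnings.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes):
    """Hypothetical, much-simplified autobatch: one array in, one array out,
    both batched along axis 0; no reduction and no overflow warning."""

    def batched(x):
        out = jax.eval_shape(func, x)  # output shape/dtype without running func
        n = x.shape[0]
        io_nbytes = x.size * x.dtype.itemsize + math.prod(out.shape) * out.dtype.itemsize
        # fewest batches keeping each batch's input + output under the limit,
        # capped at one row per batch
        nbatches = min(n, max(1, math.ceil(io_nbytes / max_io_nbytes)))
        # bump to a divisor of n so that all batches have the same size
        while n % nbatches:
            nbatches += 1
        xs = x.reshape(nbatches, n // nbatches, *x.shape[1:])
        ys = jax.lax.map(func, xs)  # sequential loop over the batches
        return ys.reshape(n, *ys.shape[2:])

    return batched


def adjusted_limit(max_io_nbytes, core_io_size, core_int_size):
    # the rule from the Notes: leave room for intermediates that scale with the batch
    return max_io_nbytes / (1 + core_int_size / core_io_size)


def f(x):
    return jnp.sin(x) * 2.0


x = jnp.arange(12.0)  # with default float32: 48 bytes in + 48 bytes out
y = autobatch_sketch(f, max_io_nbytes=40)(x)  # 3 batches of 4 elements each
```

For example, with a 1 GiB budget, 16 bytes of input/output and 48 bytes of intermediates per batched element, adjusted_limit(2**30, 16, 48) gives 2**28 bytes (256 MiB): that admits 2**24 elements, each costing 64 bytes in total, which is exactly 1 GiB.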
bartz.jaxext.scipy.special
Mockup of the scipy.special module.
- bartz.jaxext.scipy.special.gammainccinv(a, y)
Inverse of the survival function of the Gamma(a, 1) distribution.
- Return type:
Float[Array, '*']
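gammainccinv(a, y) is the x for which Q(a, x) = y, where Q is the regularized upper incomplete gamma function (jax.scipy.special.gammaincc), i.e. the survival function of Gamma(a, 1). A plain bisection sketch of that definition, assuming a > 0 and 0 < y < 1; gammainccinv_bisect is hypothetical, slower than a dedicated inverse, and not necessarily how bartz computes it.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc


def gammainccinv_bisect(a, y, num_iter=100):
    """Hypothetical: x such that gammaincc(a, x) == y, for a > 0 and 0 < y < 1."""
    a, y = jnp.broadcast_arrays(jnp.asarray(a, dtype=float), jnp.asarray(y, dtype=float))
    # gammaincc(a, .) decreases from 1 at x = 0 towards 0, so the root can be
    # bracketed and then bisected
    lo = jnp.zeros_like(a)
    hi = jnp.ones_like(a)
    # double hi until gammaincc(a, hi) <= y everywhere, i.e. the root is in [lo, hi]
    hi = jax.lax.while_loop(
        lambda h: jnp.any(gammaincc(a, h) > y),
        lambda h: jnp.where(gammaincc(a, h) > y, 2 * h, h),
        hi,
    )

    def step(_, bracket):
        lo, hi = bracket
        mid = (lo + hi) / 2
        # survival at mid still above target: the root lies to the right of mid
        right = gammaincc(a, mid) > y
        return jnp.where(right, mid, lo), jnp.where(right, hi, mid)

    lo, hi = jax.lax.fori_loop(0, num_iter, step, (lo, hi))
    return (lo + hi) / 2
```

As a check, Gamma(1, 1) is the Exp(1) distribution, whose survival function is exp(-x), so gammainccinv_bisect(1.0, 0.5) should be close to log 2.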
- bartz.jaxext.scipy.special.ndtri(p)
Compute the inverse of the CDF of the standard normal distribution.
This is a patch of jax.scipy.special.ndtri.
- Return type:
Float[Array, '*']
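The example below exercises jax.scipy.special.ndtri itself (the function being patched, not the bartz version) to show the defining property: ndtri is the quantile function of the standard normal distribution, so applying the CDF ndtr to its output recovers p.

```python
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri

p = jnp.array([0.025, 0.5, 0.975])
z = ndtri(p)  # standard normal quantiles, approximately [-1.96, 0, 1.96]
# ndtri inverts ndtr, the standard normal CDF
round_trip = ndtr(z)
```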
bartz.jaxext.scipy.stats
Mockup of the scipy.stats module.