bartz.mcmcstep.BatchedReduction

class bartz.mcmcstep.BatchedReduction(num_batches=None, batches_inner=True, contiguous=False)

Segment-sum with optional batching along the datapoints.

Fastest at the usual tree sizes. See AutoBatchedReduction to resolve num_batches automatically per platform.
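The sketch below shows, in plain Python, what a segment-sum with optional datapoint batching computes. It is an illustration only, not bartz's implementation, which operates on JAX arrays. The function `batched_segment_sum` and its parameters are hypothetical names. With `num_batches=None`, every datapoint adds into one shared buffer. Otherwise each batch accumulates into its own buffer, and the per-batch partial sums are added together at the end.

```python
from typing import List, Optional, Sequence


def batched_segment_sum(
    values: Sequence[float],
    segment_ids: Sequence[int],
    num_segments: int,
    num_batches: Optional[int] = None,
) -> List[float]:
    """Hypothetical sketch: sum `values` per segment, optionally batched."""
    if num_batches is None:
        # Unbatched reduce: every datapoint scatters into one shared buffer.
        out = [0.0] * num_segments
        for v, s in zip(values, segment_ids):
            out[s] += v
        return out

    # Batched reduce: each batch owns a private buffer, so datapoints in
    # different batches never accumulate into the same slot.
    partial = [[0.0] * num_segments for _ in range(num_batches)]
    for i, (v, s) in enumerate(zip(values, segment_ids)):
        b = i % num_batches  # strided assignment, as with contiguous=False
        partial[b][s] += v

    # Final reduction over the batch axis.
    return [sum(partial[b][s] for b in range(num_batches)) for s in range(num_segments)]


values = [1.0, 2.0, 3.0, 4.0, 5.0]
segment_ids = [0, 1, 0, 2, 1]
print(batched_segment_sum(values, segment_ids, 3))  # [4.0, 7.0, 4.0]
print(batched_segment_sum(values, segment_ids, 3, num_batches=2))  # [4.0, 7.0, 4.0]
```

Both paths produce the same per-segment totals. In general they can differ only by floating-point rounding, because batching changes the order in which the terms are added.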

num_batches: int | None = None

The number of datapoint batches, or None (the default) for an unbatched reduce.

batches_inner: bool = True

Whether the batch axis sits on the scatter buffer’s inner, contiguous axis (size-by-num_batches) or its outer axis (num_batches-by-size); the two layouts give the backend different memory access patterns. True (the default) matches the historical layout. No effect when num_batches is None.
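To make the two layouts concrete, the hypothetical helper `flat_offset` below computes where the partial sum for a given (segment, batch) pair lands in a flattened scatter buffer. It assumes row-major (C-order) flattening, the NumPy and JAX default. bartz does not necessarily address its buffer this way internally.

```python
def flat_offset(segment: int, batch: int, size: int, num_batches: int, batches_inner: bool) -> int:
    """Hypothetical: row-major flat offset of (segment, batch) in the buffer."""
    if batches_inner:
        # Shape (size, num_batches): one segment's batch slots are adjacent.
        return segment * num_batches + batch
    # Shape (num_batches, size): each batch's slots form one contiguous row.
    return batch * size + segment


# Where segment 1's two partial sums live, for size=3 and num_batches=2.
print([flat_offset(1, b, 3, 2, True) for b in range(2)])  # [2, 3]
print([flat_offset(1, b, 3, 2, False) for b in range(2)])  # [1, 4]
```

With the inner layout, the partial sums of one segment are neighbours in memory. With the outer layout, they are `size` elements apart, and the slots written by one batch are contiguous instead.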

contiguous: bool = False

How datapoints are assigned to batches. False (the default) strides them, sending datapoint i to batch i % num_batches; True splits them into contiguous chunks, sending i to batch i // batch_size, where batch_size is the number of datapoints per chunk. No effect when num_batches is None.
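The two assignment rules can be compared with a short sketch. The helper `batch_index` is hypothetical. It also assumes that batch_size is the ceiling of n / num_batches for n datapoints, which the description above does not specify. Under that assumption the last chunk may be shorter than the others.

```python
def batch_index(i: int, n: int, num_batches: int, contiguous: bool) -> int:
    """Hypothetical: batch assigned to datapoint i out of n."""
    if contiguous:
        # Assumed definition: ceil(n / num_batches), computed with integers.
        batch_size = (n + num_batches - 1) // num_batches
        return i // batch_size
    # Strided: consecutive datapoints cycle through the batches.
    return i % num_batches


print([batch_index(i, 10, 3, False) for i in range(10)])  # [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
print([batch_index(i, 10, 3, True) for i in range(10)])  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
```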