bartz.mcmcstep.PallasReduction
- class bartz.mcmcstep.PallasReduction(block_size='auto', num_blocks='auto', auto_gpu_target=1024, backend='triton')
Blocked one-hot scatter-add written as a Pallas kernel.
Splits the datapoints into blocks and, for each block, contracts the values against a one-hot leaf-membership matrix held in fast memory, accumulating the block partials. Unlike OneHotReduction, the one-hot product is guaranteed to stay fused (it is never written back to main memory). Targets gpu/tpu; on cpu it falls back to Pallas interpret mode, which is slow and meant only for testing. Like OneHotReduction, it is competitive only when the number of output bins is small. Does not support sharding the datapoints across devices.
On gpu the kernel is lowered through Triton or Mosaic GPU; see backend.
- block_size: int | Literal['auto'] = 'auto'
Datapoints contracted per kernel iteration, i.e., the width of the one-hot tile in fast memory. If ‘auto’, chosen to keep that tile small. Should be a power of 2 on gpu.
- num_blocks: int | Literal['auto'] = 'auto'
Number of kernel instances (grid size) the datapoints are split across, each looping over its share. More instances raise occupancy but enlarge the partial-sum buffer. If ‘auto’, resolved per-platform at trace time.
- auto_gpu_target: int = 1024
Cap on the number of kernel instances on gpu when num_blocks is ‘auto’.
- backend: Literal['triton', 'cpu', 'default'] = 'triton'
How to lower the kernel. The run platform is not known here at trace time, so it cannot be selected automatically:
- ‘triton’
Pass Triton compiler params. This is the default, and it compiles on every CUDA/ROCm gpu.
- ‘cpu’
Pallas interpret mode, the only mode that runs on cpu (slow; for testing).
- ‘default’
Pass nothing, leaving jax to pick its own gpu backend: that is Mosaic GPU, which only compiles on Hopper and newer (compute capability 9.0+).
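The blocked one-hot contraction described for this class can be sketched in plain JAX. This is not the Pallas kernel and makes no claim about how bartz implements it: the function name `blocked_one_hot_sum`, its arguments, and the contiguous assignment of datapoints to kernel instances are all assumptions made for illustration. Here `block_size` is the width of the one-hot tile, and `num_blocks` is the number of independent instances, each of which loops over its share and yields one row of the partial-sum buffer.

```python
import jax
import jax.numpy as jnp


def blocked_one_hot_sum(values, leaf_indices, num_bins, block_size, num_blocks):
    """Hypothetical helper: sum `values` into `num_bins` bins via one-hot tiles."""
    n = values.shape[0]
    # Pad so the datapoints split evenly into num_blocks shares of whole tiles.
    # Padded points carry value 0, so they add nothing to the bin they land in.
    pad = -n % (block_size * num_blocks)
    shape = (num_blocks, -1, block_size)
    values = jnp.pad(values, (0, pad)).reshape(shape)
    leaf_indices = jnp.pad(leaf_indices, (0, pad)).reshape(shape)
    bins = jnp.arange(num_bins)

    def one_instance(vals, idx):
        # One instance loops over its share, one tile of block_size points
        # per iteration.
        def body(acc, block):
            v, i = block
            # One-hot leaf-membership tile of shape (num_bins, block_size).
            one_hot = (i[None, :] == bins[:, None]).astype(v.dtype)
            # Contract the tile against the values and accumulate.
            return acc + one_hot @ v, None

        acc, _ = jax.lax.scan(body, jnp.zeros(num_bins, vals.dtype), (vals, idx))
        return acc

    # Partial-sum buffer of shape (num_blocks, num_bins), then the final sum.
    partials = jax.vmap(one_instance)(values, leaf_indices)
    return partials.sum(axis=0)


values = jnp.arange(10, dtype=jnp.float32)
leaf_indices = jnp.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
result = blocked_one_hot_sum(values, leaf_indices, num_bins=3, block_size=2, num_blocks=2)
# Reference result from an ordinary scatter-add.
reference = jnp.zeros(3).at[leaf_indices].add(values)
```

In this sketch the one-hot tile costs `num_bins * block_size` elements per iteration and the partial-sum buffer costs `num_blocks * num_bins`, which is consistent with the notes above that the approach is competitive only for a small number of bins and that more instances enlarge the buffer.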