Coverage for src/bartz/mcmcstep/_reduction.py: 99%
244 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-07-02 09:03 +0000
1# bartz/src/bartz/mcmcstep/_reduction.py
2#
3# Copyright (c) 2024-2026, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Indexed-reduce (scatter-add) configs, one per algorithm, and the core ops."""
27import math
28from abc import abstractmethod
29from functools import partial
30from typing import Literal
32import jax
33from equinox import Module, field
34from jax import ShapeDtypeStruct, lax
35from jax import numpy as jnp
36from jax.experimental import pallas as pl
37from jax.extend.backend import backends
38from jax.typing import DTypeLike
39from jaxtyping import Array, Float, Int32, Integer, Shaped, UInt
# Automatic cpu batching: aim for `_AUTO_CPU_TARGET` datapoint batches, but
# never let a batch hold fewer than `_AUTO_CPU_MIN_BATCH` datapoints. Unlike the
# gpu heuristic these numbers are fixed, since the cpu has no counterpart of
# the SM count to scale them with.
_AUTO_CPU_TARGET = 16
_AUTO_CPU_MIN_BATCH = 32

# Element budget for the per-instance one-hot tile of `PallasReduction` when
# its block size is chosen automatically.
_AUTO_PALLAS_TILE = 1 << 12

# Placeholder SM count for tracing the gpu branch of `AutoBatchedReduction`
# when no cuda backend is present. `lax.platform_dependent` traces that branch
# even without a gpu and drops it at lowering, so this number only keeps the
# discarded trace well-formed and never sizes a real batch grid (an actual
# cuda backend reports its own count).
_MOOT_GPU_SM = 1
class ReductionConfig(Module):
    """Base class that chooses and configures a scatter-add (indexed reduce).

    Every concrete subclass stands for one reduction algorithm, and its fields
    are that algorithm's options. Hand an instance to `init` to decide how the
    residuals, the counts and the likelihood precisions get summed over the
    datapoints of each leaf.
    """

    @abstractmethod
    def _reduce(
        self,
        values: Float[Array, '*batch_shape n'] | int,
        indices: UInt[Array, ' n'],
        /,
        *,
        size: int,
        subset_start: Integer[Array, ''] | None = None,
        subset_length: int | None = None,
        dtype: DTypeLike,
        data_sharded: bool,
        # The trailing output axis counts the reduced bins: `subset_length`
        # when a range is given, `size` otherwise. jaxtyping evaluates the
        # `{...}` dimension against the call's arguments, but it rejects
        # spaces and formats arrays as strings. Hence the tuple indexed by a
        # bool in place of a conditional expression. A zero-length range reads
        # as "no subset" there and gets the wrong dimension label. That is
        # harmless, because the only caller passes the nonempty two-element
        # child pair.
    ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
        """Sum `values` into the bins given by `indices`, along the last axis.

        Parameters
        ----------
        values
            Per-datapoint values to accumulate, or a scalar `int` that gives
            every datapoint the same weight (this is how per-bin counts are
            obtained).
        indices
            For each datapoint, the bin it belongs to.
        size
            Static, exclusive upper bound on `indices`. It is also the number
            of output bins when no subset is requested.
        subset_start
            Together with `subset_length`, restricts the output to the
            contiguous bins ``[subset_start, subset_start + subset_length)``.
            Datapoints in other bins are skipped. Bins of the range at or past
            `size` come out as zero.
        subset_length
            Static number of bins in that range, or `None` to reduce into all
            `size` bins.
        dtype
            Dtype of the output and of the accumulation. The values keep their
            own, possibly narrower, dtype until they are accumulated.
        data_sharded
            Whether the datapoint axis is sharded. If so, the result is summed
            with `psum` over the ``'data'`` axis of the enclosing `shard_map`.

        Returns
        -------
        The per-bin sums: the leading axes of `values`, followed by a trailing axis of bins.
        """
        ...
def _resolve_range(
    indices: UInt[Array, ' n'],
    size: int,
    subset_start: Integer[Array, ''] | None,
    subset_length: int | None,
) -> tuple[int, UInt[Array, ' n']]:
    """Rewrite a contiguous bin range as a full reduce, for scatter algorithms.

    Parameters
    ----------
    indices
        The bin of each datapoint, in ``[0, size)``.
    size
        The total number of bins.
    subset_start
        First bin of the range to reduce into, or `None` for every bin.
    subset_length
        Static length of the range, or `None` for every bin.

    Returns
    -------
    out_size : int
        How many output bins there are: `subset_length` with a range, else
        `size`.
    indices : UInt[Array, ' n']
        Scatter targets into those output bins. Without a range they are the
        input `indices`. With a range they are each datapoint's distance from
        `subset_start`, computed in the indices' own unsigned dtype. Datapoints
        outside ``[subset_start, subset_start + subset_length)`` thus get
        targets past the end of the output, and the scatter drops them.
    """
    if subset_length is not None:
        # The subtraction is done in unsigned arithmetic. Indices below
        # `subset_start` wrap around to huge values instead of becoming
        # negative, which the scatter would treat as indexing from the end.
        # Like indices ``>= subset_start + subset_length``, they then fall
        # past the output and are dropped. This is exact as long as the range
        # fits in the index dtype (``subset_start + subset_length <= 2 **
        # bits``), which holds for the per-move child pair.
        assert subset_start is not None  # always set together with subset_length
        assert jnp.issubdtype(indices.dtype, jnp.unsignedinteger)
        start = subset_start.astype(indices.dtype)
        return subset_length, indices - start
    return size, indices
class BatchedReduction(ReductionConfig):
    """Segment-sum, optionally split into batches of datapoints.

    The fastest option at common tree sizes. `AutoBatchedReduction` picks
    `num_batches` per platform instead of taking it as an option.
    """

    num_batches: int | None = field(static=True, default=None)
    """How many datapoint batches to use. `None` (the default) means a single,
    unbatched reduce."""

    batches_inner: bool = field(static=True, default=True)
    """Whether the batch axis goes on the inner, contiguous axis of the scatter
    buffer (``size``-by-``num_batches``) or on its outer axis
    (``num_batches``-by-``size``). The two layouts differ in memory access
    pattern. `True` (the default) is the historical layout. Ignored when
    `num_batches` is `None`."""

    contiguous: bool = field(static=True, default=False)
    """Rule for assigning datapoints to batches. `False` (the default)
    interleaves them, putting datapoint ``i`` in batch ``i % num_batches``.
    `True` cuts them into consecutive chunks, putting ``i`` in batch
    ``i // batch_size``. Ignored when `num_batches` is `None`."""

    def _reduce(
        self,
        values: Float[Array, '*batch_shape n'] | int,
        indices: UInt[Array, ' n'],
        /,
        *,
        size: int,
        subset_start: Integer[Array, ''] | None = None,
        subset_length: int | None = None,
        dtype: DTypeLike,
        data_sharded: bool,
    ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
        values = jnp.asarray(values)
        assert values.ndim == 0 or values.shape[-1:] == indices.shape
        out_size, targets = _resolve_range(indices, size, subset_start, subset_length)
        lead = values.shape[:-1]
        k = self.num_batches

        if k is None:
            # plain scatter-add straight into the output bins
            result = jnp.zeros((*lead, out_size), dtype).at[..., targets].add(values)
        else:
            # when data-sharded, this is the local shard length, not the global n
            (n_local,) = targets.shape
            # an unsigned iota spares the scatter a negative-index fixup
            position = jnp.arange(n_local, dtype=jnp.uint32)
            if self.contiguous:
                chunk = -(-n_local // k)  # ceiling division; the last batch may be short
                slot = position // chunk
            else:
                slot = position % k
            # scatter into a 2-d buffer with a batch axis, then sum that axis away
            if self.batches_inner:
                buffer = jnp.zeros((*lead, out_size, k), dtype)
                result = buffer.at[..., targets, slot].add(values).sum(axis=-1)
            else:
                buffer = jnp.zeros((*lead, k, out_size), dtype)
                result = buffer.at[..., slot, targets].add(values).sum(axis=-2)

        return lax.psum(result, 'data') if data_sharded else result
class AutoBatchedReduction(ReductionConfig):
    """A `BatchedReduction` whose `num_batches` is chosen per platform.

    On cpu the batch count aims at a fixed target. On gpu it grows with the SM
    count and shrinks with the multivariate outcome size. Only cpu and cuda
    are supported; other platforms fail at lowering.
    """

    min_batch_size: float = field(static=True, default=128.0)
    """Fewest datapoints per batch on gpu, which limits the batch count to
    ``n / min_batch_size``."""

    beta_sm: float = field(static=True, default=48.0)
    """Batches per streaming multiprocessor on gpu. The batch count levels off
    at ``beta_sm * n_sms * m ** -gamma``, where `n_sms` is the gpu's SM count
    and ``m`` the multivariate work per datapoint."""

    gamma: float = field(static=True, default=0.4)
    """Exponent with which multivariate outcomes (``m`` values per datapoint)
    lower the gpu batch count at which it levels off."""

    batches_inner: bool = field(static=True, default=True)
    """Same as `BatchedReduction`."""

    contiguous: bool = field(static=True, default=False)
    """Same as `BatchedReduction`."""

    def _reduce(
        self,
        values: Float[Array, '*batch_shape n'] | int,
        indices: UInt[Array, ' n'],
        /,
        *,
        size: int,
        subset_start: Integer[Array, ''] | None = None,
        subset_length: int | None = None,
        dtype: DTypeLike,
        data_sharded: bool,
    ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
        # Let XLA pick between the cpu and gpu variants. Both get traced, but
        # only the run platform's is lowered. No `default` is given, so
        # untested platforms (rocm, tpu) fail at lowering rather than silently
        # falling back.
        def bound(branch):  # noqa: ANN001, ANN202
            return partial(
                branch,
                values,
                indices,
                size=size,
                subset_start=subset_start,
                subset_length=subset_length,
                dtype=dtype,
                data_sharded=data_sharded,
            )

        return lax.platform_dependent(
            cpu=bound(self._reduce_cpu), cuda=bound(self._reduce_gpu)
        )

    def _reduce_cpu(
        self,
        values: Float[Array, '*batch_shape n'] | int,
        indices: UInt[Array, ' n'],
        /,
        *,
        size: int,
        subset_start: Integer[Array, ''] | None = None,
        subset_length: int | None = None,
        dtype: DTypeLike,
        data_sharded: bool,
    ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
        (n,) = indices.shape
        # a fixed target: the cpu has no SM count to scale the batch count by
        resolved = _final_round(n, _AUTO_CPU_TARGET, _AUTO_CPU_MIN_BATCH)
        batched = self._delegate(resolved)
        return batched._reduce(  # noqa: SLF001
            values,
            indices,
            size=size,
            subset_start=subset_start,
            subset_length=subset_length,
            dtype=dtype,
            data_sharded=data_sharded,
        )

    def _reduce_gpu(
        self,
        values: Float[Array, '*batch_shape n'] | int,
        indices: UInt[Array, ' n'],
        /,
        *,
        size: int,
        subset_start: Integer[Array, ''] | None = None,
        subset_length: int | None = None,
        dtype: DTypeLike,
        data_sharded: bool,
    ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
        # When data-sharded, `n` is the local shard length, which is the size
        # the heuristic is meant to see.
        (n,) = indices.shape
        # Multivariate work per scatter slot: the product of the values'
        # non-datapoint axes (1 for a scalar count). It is clamped to >= 1 so
        # a zero-size leading axis (e.g. k=0 outcome components) gets the m=1
        # cap for what is an empty reduce anyway.
        work = max(1, math.prod(jnp.shape(values)[:-1]))
        # the cap rises with the SM count and falls, sublinearly, with `work`
        sm_cap = self.beta_sm * _gpu_sm_count() * work ** (-self.gamma)
        batched = self._delegate(_final_round(n, sm_cap, self.min_batch_size))
        return batched._reduce(  # noqa: SLF001
            values,
            indices,
            size=size,
            subset_start=subset_start,
            subset_length=subset_length,
            dtype=dtype,
            data_sharded=data_sharded,
        )

    def _delegate(self, num_batches: int | None) -> BatchedReduction:
        """Make a `BatchedReduction` with `num_batches` and this config's layout options."""
        return BatchedReduction(
            num_batches=num_batches,
            batches_inner=self.batches_inner,
            contiguous=self.contiguous,
        )
358def _final_round(
359 n: int, target: float | int, min_batch_size: float | int
360) -> int | None:
361 """Cap batches to keep them above `min_batch_size`, round to a power of 2, and disable batching if there's only 1 batch."""
362 # at least `min_batch_size` elements per batch
363 num = min(n / min_batch_size, target)
365 # round to the nearest power of 2 because I guess XLA and the hardware
366 # will like that (not sure about this, maybe just multiple of 32?)
367 num = 2 ** round(math.log2(num)) if num > 0 else 0
369 # disable batching if the batch is as large as the whole dataset
370 return num if num > 1 else None
def _gpu_sm_count() -> int:
    """Return the streaming-multiprocessor count common to the visible cuda gpus.

    `AutoBatchedReduction` reads this to size its gpu batch grid. Since
    `lax.platform_dependent` runs the gpu branch only on cuda, the count comes
    straight from each device in `jax.devices('cuda')` (its `core_count`)
    rather than being guessed. Gpus reporting different counts raise
    `ValueError`, because a mixed set of gpu models is not supported.
    """
    if 'cuda' in backends():
        sm_counts = {gpu.core_count for gpu in jax.devices('cuda')}
        if len(sm_counts) > 1:
            msg = (
                f'visible cuda gpus report differing SM counts {sorted(sm_counts)}; '
                'AutoBatchedReduction assumes a single gpu model'
            )
            raise ValueError(msg)
        (only,) = sm_counts
        return only
    # Without a cuda backend, lax.platform_dependent still traces the gpu
    # branch only to discard it at lowering, so any count will do.
    return _MOOT_GPU_SM
397class OneHotReduction(ReductionConfig):
398 """Dense one-hot reduction.
400 Materializes the membership of each datapoint in its leaf as a one-hot
401 matrix over the output bins and contracts it against the values. Beats
402 `BatchedReduction` only when the number of bins is very small (e.g. a
403 single leaf pair), or on gpu for multivariate residuals.
404 """
406 method: Literal['matmul', 'multiply', 'scatter_set'] = field(
407 static=True, default='matmul'
408 )
409 """How to contract the values against the one-hot leaf-membership matrix:
411 'matmul'
412 Contract the values with the one-hot matrix via a dot. Faster on gpu,
413 especially for multivariate residuals.
414 'multiply'
415 Elementwise-multiply by the one-hot matrix and reduce over the
416 datapoints; whether the ``n``-by-``size`` product is fused into the
417 reduction or materialized is left to the backend. Faster on cpu.
418 'scatter_set'
419 Scatter the values into a dense buffer with unique (non-atomic) writes,
420 then sum over the datapoints.
421 """
423 n_inner: bool = field(static=True, default=True)
424 """Whether the datapoints sit on the one-hot's inner, contiguous axis
425 (``size``-by-``n``) or its outer axis (``n``-by-``size``); the two layouts
426 give the backend different memory access patterns. `True` (the default)
427 fuses better on gpu."""
429 def _reduce(
430 self,
431 values: Float[Array, '*batch_shape n'] | int,
432 indices: UInt[Array, ' n'],
433 /,
434 *,
435 size: int,
436 subset_start: Integer[Array, ''] | None = None,
437 subset_length: int | None = None,
438 dtype: DTypeLike,
439 data_sharded: bool,
440 ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
441 values = jnp.asarray(values)
442 assert values.ndim == 0 or values.shape[-1:] == indices.shape
444 # a scalar value is the count case, weighting each datapoint by `values`;
445 # it broadcasts in the scatter/multiply paths, only matmul needs a vector
446 scalar = values.ndim == 0
447 (n,) = indices.shape
448 batch_shape = values.shape[:-1]
449 size, bins, indices = _resolve_range_bins(
450 indices,
451 size,
452 subset_start,
453 subset_length,
454 remap=self.method == 'scatter_set',
455 )
456 # unsigned avoids a negative-index normalization select in the scatter
457 iota = jnp.arange(n, dtype=jnp.uint32)
459 # one-hots and scatter buffers hold the values, so they are built in
460 # the values' dtype; only the reduction accumulates in `dtype`. The
461 # scalar count case has no input precision to preserve and uses `dtype`.
462 values_dtype = dtype if scalar else values.dtype
464 match self.method, self.n_inner:
465 case 'scatter_set', True:
466 out = (
467 jnp.zeros((*batch_shape, size, n), values_dtype)
468 .at[..., indices, iota]
469 .set(values, unique_indices=True)
470 .sum(axis=-1, dtype=dtype)
471 )
472 case 'scatter_set', False:
473 out = (
474 jnp.zeros((*batch_shape, n, size), values_dtype)
475 .at[..., iota, indices]
476 .set(values, unique_indices=True)
477 .sum(axis=-2, dtype=dtype)
478 )
479 case 'matmul', True:
480 onehot = (bins[:, None] == indices).astype(values_dtype) # (size, n)
481 vec = jnp.broadcast_to(values.astype(dtype), (n,)) if scalar else values
482 out = jnp.einsum(
483 '...n,sn->...s', vec, onehot, preferred_element_type=dtype
484 )
485 case 'matmul', False:
486 onehot = (indices[:, None] == bins).astype(values_dtype) # (n, size)
487 vec = jnp.broadcast_to(values.astype(dtype), (n,)) if scalar else values
488 out = jnp.einsum(
489 '...n,ns->...s', vec, onehot, preferred_element_type=dtype
490 )
491 case 'multiply', True:
492 onehot = bins[:, None] == indices # (size, n)
493 if scalar:
494 out = values * onehot.sum(axis=-1, dtype=dtype)
495 else:
496 out = (values[..., None, :] * onehot).sum(axis=-1, dtype=dtype)
497 case 'multiply', False: 497 ↛ 504line 497 didn't jump to line 504 because the pattern on line 497 always matched
498 onehot = indices[:, None] == bins # (n, size)
499 if scalar:
500 out = values * onehot.sum(axis=-2, dtype=dtype)
501 else:
502 out = (values[..., :, None] * onehot).sum(axis=-2, dtype=dtype)
504 if data_sharded:
505 out = lax.psum(out, 'data')
506 return out
def _resolve_range_bins(
    indices: UInt[Array, ' n'],
    size: int,
    subset_start: Integer[Array, ''] | None,
    subset_length: int | None,
    *,
    remap: bool,
) -> tuple[int, UInt[Array, ' out_size'], UInt[Array, ' n']]:
    """Turn a range subset into an output size, comparison bins and scatter indices.

    The comparison methods of `OneHotReduction` compare the indices directly
    against the range's bin numbers. Its scatter method (`remap=True`) instead
    addresses bins by position, so there the indices are shifted as in
    `_resolve_range`.
    """
    if subset_length is None:
        every_bin = jnp.arange(size, dtype=indices.dtype)
        return size, every_bin, indices
    assert subset_start is not None  # always set together with subset_length
    # Build the bins in uint32 rather than the possibly narrow
    # `indices.dtype`, so that bins past `size` cannot wrap around and alias
    # a real bin in the comparison.
    first = subset_start.astype(jnp.uint32)
    bins = first + jnp.arange(subset_length, dtype=jnp.uint32)
    if not remap:
        return subset_length, bins, indices
    out_size, shifted = _resolve_range(indices, size, subset_start, subset_length)
    return out_size, bins, shifted
class AutoOneHotReduction(ReductionConfig):
    """A `OneHotReduction` whose `method` and `n_inner` are chosen automatically.

    Both options are resolved per call site and platform from information
    available at trace time, then a plain `OneHotReduction` does the work.
    `matmul` is used only for multivariate reductions into many bins, and
    `multiply` everywhere else. Datapoints go on the outer axis, except at the
    two small-bin sites where the other layout wins (cpu precision, cuda
    count). Only those two sites are platform-split, and they accept cpu and
    cuda only, failing at lowering on other platforms.

    The call site is inferred from `values`: a scalar means the count, a wide
    output the residual, and a narrow non-scalar output the precision.

    Known limitation: for the wide-bin univariate residual on cpu beyond
    ~10^6 datapoints, the other layout would be faster (by up to ~2x).
    """

    min_matmul_bins: int = field(static=True, default=8)
    """Fewest output bins for which `matmul` may be picked; below this,
    `multiply` is always used."""

    def _reduce(
        self,
        values: Float[Array, '*batch_shape n'] | int,
        indices: UInt[Array, ' n'],
        /,
        *,
        size: int,
        subset_start: Integer[Array, ''] | None = None,
        subset_length: int | None = None,
        dtype: DTypeLike,
        data_sharded: bool,
    ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
        bins_out = size if subset_length is None else subset_length
        work = max(1, math.prod(jnp.shape(values)[:-1]))
        use_matmul = work >= 2 and bins_out >= self.min_matmul_bins
        method = 'matmul' if use_matmul else 'multiply'

        # identify the call site, then pick `n_inner` for (cpu, cuda)
        if jnp.ndim(values) == 0:  # count: a scalar weight
            cpu_inner, cuda_inner = False, True
        elif bins_out <= 2:  # precision: narrow and non-scalar
            cpu_inner, cuda_inner = True, False
        else:  # residual: wide output
            cpu_inner, cuda_inner = False, False

        def reducer(n_inner: bool):  # noqa: ANN202
            config = OneHotReduction(method=method, n_inner=n_inner)
            return partial(
                config._reduce,  # noqa: SLF001
                values,
                indices,
                size=size,
                subset_start=subset_start,
                subset_length=subset_length,
                dtype=dtype,
                data_sharded=data_sharded,
            )

        if cpu_inner != cuda_inner:
            # Let XLA pick the branch: both are traced, only the run
            # platform's is lowered. No `default` is given, so an untested
            # platform fails at lowering rather than silently falling back.
            return lax.platform_dependent(
                cpu=reducer(cpu_inner), cuda=reducer(cuda_inner)
            )
        # The layout is the same on every platform, so there is no split, and
        # the reduction also runs on untested platforms (tpu/rocm).
        return reducer(cpu_inner)()
class PallasReduction(ReductionConfig):
    """Blocked one-hot scatter-add implemented as a Pallas kernel.

    The datapoints are cut into blocks. For each block, the values are
    contracted against a one-hot leaf-membership matrix kept in fast memory,
    and the block partials are accumulated. Unlike `OneHotReduction`, the
    one-hot product is guaranteed to stay fused (it is never written back to
    main memory). Aimed at gpu/tpu. On cpu it runs in Pallas interpret mode,
    which is slow and only meant for testing. As with `OneHotReduction`, it
    pays off only when there are few output bins. Sharding the datapoints
    across devices is not supported.

    On gpu the kernel is lowered through Triton or Mosaic GPU; see `backend`.
    """

    block_size: int | Literal['auto'] = field(static=True, default='auto')
    """Datapoints contracted per kernel iteration, i.e., the width of the
    one-hot tile in fast memory. With 'auto' it is chosen to keep that tile
    small. Should be a power of 2 on gpu."""

    num_blocks: int | Literal['auto'] = field(static=True, default='auto')
    """Number of kernel instances (grid size) the datapoints are divided
    among, each looping over its part. More instances raise occupancy but
    enlarge the buffer of partial sums. With 'auto' it is resolved per
    platform at trace time."""

    auto_gpu_target: int = field(static=True, default=1024)
    """Upper bound on the number of kernel instances on gpu when `num_blocks`
    is 'auto'."""

    backend: Literal['triton', 'cpu', 'default'] = field(static=True, default='triton')
    """How the kernel is lowered. The run platform is unknown at trace time,
    so this cannot be chosen automatically:

    'triton'
        Pass Triton compiler params. The default; compiles on every CUDA/ROCm
        gpu.
    'cpu'
        Pallas interpret mode, the only mode that works on cpu (slow; for
        testing).
    'default'
        Pass nothing and let jax choose its own gpu backend. That is Mosaic
        GPU, which only compiles on Hopper and newer (compute capability 9.0+).
    """

    def _reduce(
        self,
        values: Float[Array, '*batch_shape n'] | int,
        indices: UInt[Array, ' n'],
        /,
        *,
        size: int,
        subset_start: Integer[Array, ''] | None = None,
        subset_length: int | None = None,
        dtype: DTypeLike,
        data_sharded: bool,
    ) -> Shaped[Array, '*batch_shape {(size,subset_length)[bool(subset_length)]}']:
        if data_sharded:
            # Under `shard_map` the kernel fails the vma checks, in ways that
            # vary with the jax version, even in interpret mode.
            msg = 'PallasReduction does not support a sharded data axis'
            raise NotImplementedError(msg)

        values = jnp.asarray(values)
        assert values.ndim == 0 or values.shape[-1:] == indices.shape
        (n,) = indices.shape
        num_rows = math.prod(values.shape[:-1]) if values.ndim else 1

        # A range subset is handled by giving the kernel the explicit bins to
        # compare against, which costs no more than the full case (`bins=None`).
        if subset_length is not None:
            assert subset_start is not None  # always set together with subset_length
            out_size = subset_length
            # uint32 keeps bins at or past `size` from wrapping onto real bins
            bins = subset_start.astype(jnp.uint32) + jnp.arange(
                subset_length, dtype=jnp.uint32
            )
        else:
            out_size = size
            bins = None

        # The launch settings are static. They derive from `backend` (cpu or
        # gpu) rather than the trace-time platform, which may differ from the
        # one the computation runs on.
        interpret = self.backend == 'cpu'
        compiler_params = _resolve_pallas_backend(self.backend)
        block_size = (
            _auto_block_size(n, out_size, num_rows)
            if self.block_size == 'auto'
            else self.block_size
        )
        if self.num_blocks != 'auto':
            num_blocks = self.num_blocks
        else:
            cap = _AUTO_CPU_TARGET if interpret else self.auto_gpu_target
            num_blocks = max(1, min(-(-n // block_size), cap))

        return _pallas_scatter_add(
            values,
            indices,
            size=size,
            bins=bins,
            dtype=dtype,
            num_blocks=num_blocks,
            block_size=block_size,
            interpret=interpret,
            compiler_params=compiler_params,
        )
715def _resolve_pallas_backend(
716 backend: Literal['triton', 'cpu', 'default'],
717) -> pl.CompilerParams | None:
718 """`compiler_params` for `pallas_call`, or `None` to use its own default.
720 Only 'triton' passes params (it compiles on every CUDA/ROCm gpu); 'cpu'
721 (interpret mode) and 'default' (Mosaic GPU, Hopper-only) pass nothing.
722 """
723 if backend != 'triton':
724 return None
725 else:
726 from jax.experimental.pallas import triton as pallas_triton # noqa: PLC0415
728 return pallas_triton.CompilerParams()
def _ceil_pow2(n: int) -> int:
    """Return the least power of 2 that is at least `n` (1 whenever ``n <= 1``)."""
    if n <= 1:
        return 1
    # n - 1 has exactly as many bits as the exponent we need
    return 1 << (n - 1).bit_length()
def _auto_block_size(n: int, size: int, num_rows: int) -> int:
    """Power-of-2 datapoint tile keeping the kernel's one-hot working set small.

    Parameters
    ----------
    n
        Number of datapoints.
    size
        Number of output bins.
    num_rows
        Number of value rows reduced together (1 for the scalar count case).
        May be 0 for a zero-size leading axis.

    Returns
    -------
    A power of 2 close to ``_AUTO_PALLAS_TILE / (size * num_rows)``, at least
    1 and not above the power of 2 that covers `n`.
    """
    # Clamp the per-datapoint footprint to >= 1. A zero-size value batch
    # (e.g. k=0 outcome components) or an empty bin range would otherwise
    # divide by zero. The resulting reduce is empty, so any tile is fine.
    footprint = max(1, size * num_rows)
    area = max(1, _AUTO_PALLAS_TILE // footprint)
    block_size = 1 << round(math.log2(area))  # nearest power of 2
    return min(block_size, _ceil_pow2(n))  # avoid padding beyond `n`
def _pallas_scatter_add(
    values: Float[Array, '*batch_shape n'] | Int32[Array, ''],
    indices: UInt[Array, ' n'],
    /,
    *,
    size: int,
    bins: Integer[Array, ' sub_size'] | None,
    dtype: DTypeLike,
    num_blocks: int,
    block_size: int,
    interpret: bool,
    compiler_params: pl.CompilerParams | None = None,
) -> Shaped[Array, '*batch_shape {getattr(bins,"size",size)}']:
    """Blocked one-hot indexed reduce via a Pallas kernel; see `PallasReduction`.

    `bins` are the bin indices to reduce into, `None` meaning all of ``0, 1,
    ..., size - 1``. The kernel compares the indices against them, so the
    subset case costs nothing more than the full one.

    Parameters
    ----------
    values
        Values to sum, with datapoints on the last axis, or a scalar weight
        applied to every datapoint (the count case).
    indices
        Bin index of each datapoint.
    size
        Number of bins when `bins` is `None`. It is also the out-of-domain
        fill value for padded bins and padded datapoints.
    bins
        Explicit bins to reduce into, or `None` for ``arange(size)``.
    dtype
        Accumulation and output dtype.
    num_blocks
        Grid size: number of kernel instances, each producing one partial sum.
    block_size
        Datapoints per inner-loop iteration (the one-hot tile width).
    interpret
        Whether to run `pallas_call` in interpret mode.
    compiler_params
        Passed through to `pallas_call`.

    Returns
    -------
    The per-bin sums: shape ``(out_size,)`` for a scalar `values`, else the
    leading axes of `values` followed by ``out_size`` (``size``, or the length
    of `bins`).
    """
    scalar = values.ndim == 0
    (n,) = indices.shape
    # Flatten the values into a 2-d (num_rows, n) matrix for the kernel.
    if scalar:
        # a scalar value (the count case) weights every datapoint equally
        batch_shape = ()
        rows = jnp.broadcast_to(values.astype(dtype), (1, n))
    else:
        # the rows stay in the values' dtype; the kernel accumulates in `dtype`
        batch_shape = values.shape[:-1]
        rows = values.reshape(-1, n)
    num_rows, _ = rows.shape
    if bins is None:
        out_size = size
    else:
        (out_size,) = bins.shape

    # The Triton backend requires every array dimension to be a power of 2;
    # `block_size` already is. Two axes are padded, and both pads are sliced
    # off below:
    # - the rows axis, whose length is the product of the value's batch shape
    #   (e.g. k or k*k, so any k), gets zero rows;
    # - the bins axis gets the out-of-domain bin `size`. The full `size` is a
    #   power of 2, but a range subset has any length.
    # `bins` is uint32, so `size` cannot wrap and alias a real bin.
    padded_rows = _ceil_pow2(num_rows)
    padded_size = _ceil_pow2(out_size)
    if bins is not None:
        bins = jnp.pad(bins, (0, padded_size - out_size), constant_values=size)

    # Each instance scans `iters` sub-blocks of `block_size` datapoints, and
    # the datapoint axis is padded so it splits evenly. The padded datapoints
    # are zero in every row, so they contribute nothing whatever bin they fall
    # in. The out-of-range `size` index is just a tidy default (it may wrap in
    # a narrow index dtype).
    iters = -(-n // (num_blocks * block_size))
    chunk = iters * block_size  # datapoints handled by one instance
    pad = num_blocks * chunk - n
    indices = jnp.pad(indices, (0, pad), constant_values=size)
    rows = jnp.pad(rows, ((0, padded_rows - num_rows), (0, pad)))

    # The kernel operates on `Ref`s, not arrays, so it carries no array
    # annotations (which would also trip runtime shape typechecking).
    def kernel(rows_ref, indices_ref, *refs):  # noqa: ANN001, ANN002, ANN202
        # `refs` is (out_ref,) without explicit bins, (bins_ref, out_ref) with
        # them, matching the optional fourth entry of `in_specs` below.
        if bins is None:
            (out_ref,) = refs
            kernel_bins = jnp.arange(padded_size, dtype=indices_ref.dtype)
        else:
            bins_ref, out_ref = refs
            kernel_bins = bins_ref[:]

        def accumulate(i, acc):  # noqa: ANN001, ANN202
            block = pl.ds(i * block_size, block_size)
            # (padded_size, block_size) membership mask for this sub-block
            onehot = kernel_bins[:, None] == indices_ref[block]
            # The one-hot product is exact in the values' dtype; the block
            # reduction accumulates in `dtype`. Broadcasting gives
            # (padded_rows, padded_size, block_size), summed over datapoints.
            prod = rows_ref[:, block][:, None, :] * onehot.astype(rows_ref.dtype)
            return acc + prod.sum(axis=-1, dtype=dtype)

        # this instance's partial sums go into its slice of the output
        out_ref[0] = lax.fori_loop(
            0, iters, accumulate, jnp.zeros((padded_rows, padded_size), dtype)
        )

    # instance p reads the p-th `chunk`-wide slab of the rows and the indices
    in_specs = [
        pl.BlockSpec((padded_rows, chunk), lambda p: (0, p)),
        pl.BlockSpec((chunk,), lambda p: (p,)),
    ]
    args = [rows, indices]
    if bins is not None:
        # every instance reads the whole (tiny) bins array
        in_specs.append(pl.BlockSpec((padded_size,), lambda _p: (0,)))
        args.append(bins)

    # one (padded_rows, padded_size) partial per instance, summed afterwards
    out = pl.pallas_call(
        kernel,
        out_shape=ShapeDtypeStruct((num_blocks, padded_rows, padded_size), dtype),
        grid=(num_blocks,),
        in_specs=in_specs,
        out_specs=pl.BlockSpec((1, padded_rows, padded_size), lambda p: (p, 0, 0)),
        interpret=interpret,
        compiler_params=compiler_params,
        name='scatter_add',
    )(*args)
    # combine the instance partials, then drop the power-of-2 padding rows and bins
    out = out.sum(axis=0)[:num_rows, :out_size]

    if scalar:
        return out.reshape(out_size)
    return out.reshape(*batch_shape, out_size)