bartz.mcmcstep.OneHotReduction

class bartz.mcmcstep.OneHotReduction(method='matmul', n_inner=True)

Dense one-hot reduction.

Materializes the membership of each datapoint in its leaf as a one-hot matrix over the output bins and contracts it against the values. Outperforms BatchedReduction only when the number of bins is very small (e.g. a single leaf pair), or on GPU for multivariate residuals.
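As a rough illustration of the idea, the following NumPy sketch builds the one-hot membership matrix explicitly and contracts it against the values. The function name `one_hot_reduce` and its arguments are hypothetical and chosen for this example only; this is not bartz's implementation, which runs on JAX.

```python
import numpy as np

def one_hot_reduce(values, leaf_indices, size):
    """Sum `values` per bin through a dense one-hot matrix (illustrative only)."""
    # one_hot[b, i] is 1 when datapoint i falls in bin b, else 0; shape (size, n)
    one_hot = (np.arange(size)[:, None] == leaf_indices[None, :]).astype(values.dtype)
    # contracting over the datapoint axis leaves one sum per bin
    return one_hot @ values

values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
leaf_indices = np.array([0, 2, 0, 1, 2])  # bin of each datapoint
sums = one_hot_reduce(values, leaf_indices, size=3)

# the same sums computed with a sparse scatter-add, for comparison
reference = np.bincount(leaf_indices, weights=values, minlength=3)
```

The dense matrix costs memory and arithmetic proportional to n times the number of bins, which is consistent with the approach paying off mainly when there are very few bins.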

method: Literal['matmul', 'multiply', 'scatter_set'] = 'matmul'

How to contract the values against the one-hot leaf-membership matrix:

‘matmul’

Contract the values with the one-hot matrix via a dot product. Faster on GPU, especially for multivariate residuals.

‘multiply’

Elementwise-multiply by the one-hot matrix and reduce over the datapoints; whether the n-by-size product is fused into the reduction or materialized is left to the backend. Faster on CPU.

‘scatter_set’

Scatter the values into a dense buffer with unique (non-atomic) writes, then sum over the datapoints.
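The three options can be sketched in NumPy as follows, for scalar values and an n-by-size one-hot matrix. The function names are hypothetical, and the code only mirrors the descriptions above; it says nothing about how bartz implements them in JAX or how fast each one runs.

```python
import numpy as np

def reduce_matmul(values, idx, size):
    # dot product of the values with the (n, size) one-hot matrix
    one_hot = (idx[:, None] == np.arange(size)).astype(values.dtype)
    return values @ one_hot

def reduce_multiply(values, idx, size):
    # elementwise product with the one-hot matrix, then a sum over datapoints
    one_hot = (idx[:, None] == np.arange(size)).astype(values.dtype)
    return (values[:, None] * one_hot).sum(axis=0)

def reduce_scatter_set(values, idx, size):
    # write each value into its own (row, bin) cell of a dense buffer;
    # every target cell is distinct, so writes never collide
    n = values.shape[0]
    buf = np.zeros((n, size), dtype=values.dtype)
    buf[np.arange(n), idx] = values
    return buf.sum(axis=0)

values = np.array([0.5, 1.5, -1.0, 2.0])
idx = np.array([1, 0, 1, 1])  # bin of each datapoint
results = [f(values, idx, 2) for f in (reduce_matmul, reduce_multiply, reduce_scatter_set)]
```

All three produce the same per-bin sums; they differ only in how the work is expressed to the backend.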

n_inner: bool = True

Whether the datapoints sit on the one-hot’s inner, contiguous axis (size-by-n) or its outer axis (n-by-size); the two layouts give the backend different memory access patterns. True (the default) fuses better on GPU.
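The two layouts can be pictured with a small NumPy sketch. It assumes row-major storage, so the last axis is the contiguous one, and it assumes that n_inner=True corresponds to the size-by-n layout as described above. The variable names are illustrative and not taken from bartz.

```python
import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0])
idx = np.array([1, 0, 1, 0])  # bin of each datapoint
size = 2
bins = np.arange(size)

# size-by-n: datapoints on the inner (last, contiguous) axis
one_hot_inner = (bins[:, None] == idx[None, :]).astype(values.dtype)
sums_inner = one_hot_inner @ values

# n-by-size: datapoints on the outer (first) axis
one_hot_outer = (idx[:, None] == bins[None, :]).astype(values.dtype)
sums_outer = values @ one_hot_outer
```

Both layouts give identical sums; only the order in which memory is traversed during the reduction changes.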