ADR-004: Consensus Monte Carlo for Bayesian Inference
- Status: Accepted
- Date: 2026-05-19
- Deciders: Core team
Context
The heterodyne 14-parameter model posterior
is evaluated over thousands of \((t_1, t_2)\) data points per shard. Standard MCMC (single-chain NUTS) requires an O(N) likelihood and gradient evaluation per leapfrog step, where N is the total number of time-pair observations. For a 1000-frame dataset, N ≈ 500K pairs; a single leapfrog step costs ~50 ms on a 32-core workstation, and each NUTS iteration takes many leapfrog steps. Convergence requires ~2000 warmup + 1000 production iterations, for a total of ~2.5 hours for a single chain.
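The figures above can be sanity-checked with a short back-of-envelope calculation. This is only a sketch under stated assumptions: the pair count is taken as the upper triangle of the \((t_1, t_2)\) grid including the diagonal, the ~50 ms cost is taken per leapfrog step, and the average of 60 leapfrog steps per iteration is an illustrative value chosen for this example, not a measurement. The function names are hypothetical.

```python
def num_time_pairs(num_frames: int) -> int:
    # Assumption: unique (t1, t2) pairs with t1 <= t2, diagonal included.
    return num_frames * (num_frames + 1) // 2


def chain_hours(iterations: int, leapfrogs_per_iteration: int,
                seconds_per_leapfrog: float) -> float:
    # Total cost = iterations x leapfrog steps per iteration x cost per step.
    return iterations * leapfrogs_per_iteration * seconds_per_leapfrog / 3600.0


pairs = num_time_pairs(1000)            # 500500, i.e. N ≈ 500K
hours = chain_hours(2000 + 1000,        # warmup + production iterations
                    60,                 # illustrative average trajectory length
                    0.050)              # ~50 ms per leapfrog step
print(f"pairs={pairs}, single-chain estimate ≈ {hours:.1f} h")
```

Under these assumptions the estimate lands at roughly 2.5 hours; a different average trajectory length scales the result proportionally.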
Users need posterior samples for all 14 physical parameters plus per-angle scaling within 30–90 minutes on typical beamline hardware (32–64 cores).
Decision
Heterodyne implements Consensus Monte Carlo (CMC, Scott et al. 2016) as its Bayesian backend. The algorithm is:
1. Shard: Partition the \(n\) data points into \(M\) shards of size \(n_s \approx n/M\).
2. Parallel NUTS: Run independent NUTS chains on each shard (in separate worker processes), obtaining \(K\) posterior samples \(\{\theta^{(k)}_s\}\) from the shard-specific posterior \(p_s(\theta\,|\,\mathcal{D}_s)\).
3. Consensus: Combine the \(M\) sets of shard samples into a single approximation of the global posterior \(p(\theta\,|\,\mathcal{D})\).
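The shard and consensus steps can be sketched in NumPy as follows. This is a minimal illustration, assuming the Gaussian precision weighting described by Scott et al. (each shard's draws are weighted by the inverse of that shard's sample covariance); whether the Heterodyne backend uses exactly this weighting is not stated here. The function names `make_shards` and `consensus_combine` are hypothetical, and the per-shard NUTS runs are replaced by synthetic Gaussian draws.

```python
import numpy as np


def make_shards(n, num_shards, seed=0):
    """Randomly partition indices 0..n-1 into num_shards near-equal shards."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n), num_shards)


def consensus_combine(shard_draws):
    """Combine per-shard draws, each of shape (K, d), into (K, d) consensus draws.

    Draw k of the result is (sum_s W_s)^-1 sum_s W_s theta_s^(k),
    with W_s the inverse sample covariance of shard s.
    """
    precisions = [
        np.linalg.inv(np.atleast_2d(np.cov(draws, rowvar=False)))
        for draws in shard_draws
    ]
    total_cov = np.linalg.inv(sum(precisions))
    # Row-vector form of W_s @ theta: theta @ W_s.T, summed over shards.
    weighted_sum = sum(draws @ w.T for w, draws in zip(precisions, shard_draws))
    return weighted_sum @ total_cov.T


# Synthetic stand-in for M shard posteriors: same mean, variance inflated by M.
rng = np.random.default_rng(1)
M, K = 4, 4000
shard_draws = [rng.normal(loc=[1.0, -2.0], scale=2.0, size=(K, 2)) for _ in range(M)]
combined = consensus_combine(shard_draws)
shards = make_shards(n=1000, num_shards=M)
```

In this synthetic case each shard has variance 4 per dimension, so the combined draws should have a covariance close to the identity, which is what a full-data Gaussian posterior with four times the information would give. Scott et al. also raise the prior to the power \(1/M\) in each shard so that it is not counted \(M\) times.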
The multiprocessing backend spawns \(N_\mathrm{workers} = \lfloor N_\mathrm{cores}/2\rfloor - 1\)
worker processes, each with 4 virtual JAX devices (via
--xla_force_host_platform_device_count=4). Each worker runs NUTS in parallel mode
(pmap over 4 devices), achieving near-full CPU utilization.
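The worker sizing rule can be written out as a small helper. This is a sketch only: the function names are hypothetical, and clamping the result to at least one worker is an assumption added for small machines, not something the rule above specifies.

```python
def num_workers(n_cores: int) -> int:
    # N_workers = floor(N_cores / 2) - 1, clamped to >= 1 (clamp is an assumption).
    return max(1, n_cores // 2 - 1)


def worker_env(devices_per_worker: int = 4, omp_threads: int = 1) -> dict:
    # Environment a worker would need before JAX initializes its CPU backend.
    return {
        "XLA_FLAGS": f"--xla_force_host_platform_device_count={devices_per_worker}",
        "OMP_NUM_THREADS": str(omp_threads),
    }


for cores in (32, 64):
    print(cores, "cores ->", num_workers(cores), "workers")
print(worker_env())
```

On the 32-core and 64-core machines mentioned in the Context section this gives 15 and 31 workers respectively. XLA reads `XLA_FLAGS` when the backend is initialized, so the variable has to be in place in the worker before JAX is first used there.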
NUTS Warmup Floor
num_warmup defaults to 1500 because dense_mass=True on the 14-parameter
heterodyne model requires ≥ 100 mass-matrix adaptation steps per dimension, i.e. at
least 14 × 100 = 1400 warmup steps. Lowering num_warmup without simultaneously
reducing the parameter count or disabling dense_mass will produce high R-hat values
and divergence storms. This floor is enforced by tests.
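A guard of the kind such a test might exercise could look like the sketch below. The names `min_warmup_steps` and `check_warmup` are hypothetical and are not taken from the Heterodyne code base; returning a floor of zero when `dense_mass` is disabled is also an assumption, since only the dense case is specified above.

```python
def min_warmup_steps(num_params: int, dense_mass: bool,
                     steps_per_dim: int = 100) -> int:
    # Dense mass matrix: at least steps_per_dim adaptation steps per dimension.
    # The diagonal case is not specified in the ADR; 0 here is an assumption.
    return num_params * steps_per_dim if dense_mass else 0


def check_warmup(num_warmup: int, num_params: int = 14,
                 dense_mass: bool = True) -> None:
    floor = min_warmup_steps(num_params, dense_mass)
    if num_warmup < floor:
        raise ValueError(
            f"num_warmup={num_warmup} is below the floor of {floor} "
            f"for {num_params} parameters with dense_mass={dense_mass}"
        )


check_warmup(1500)  # the default passes: 1500 >= 14 * 100
```

Calling `check_warmup(1000)` with the default 14 parameters raises `ValueError`, while the same value is accepted if `dense_mass=False` or if the parameter count drops to 10 or fewer.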
JIT Cache in Workers
In JAX 0.8+, the JAX_COMPILATION_CACHE_DIR environment variable alone does not
enable the persistent cache. Workers must call:
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
This is done in the worker initialization function in
optimization/cmc/backends/multiprocessing.py.
Environment Isolation
Before spawning workers, the backend:
1. Saves the current environment (OMP_PROC_BIND, OMP_PLACES).
2. Clears OMP_PROC_BIND and OMP_PLACES to prevent OpenMP thread binding conflicts between workers.
3. Sets OMP_NUM_THREADS=1 or 2 per worker to prevent thread oversubscription (each worker manages its own JAX device count via XLA_FLAGS).
4. Restores the parent environment after all workers are spawned.
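The save, clear, set, restore sequence maps naturally onto a context manager. The sketch below uses only the standard library; the name `isolated_omp_env` is hypothetical, and saving and restoring `OMP_NUM_THREADS` alongside the two binding variables is an assumption made so that the parent ends up exactly as it started.

```python
import os
from contextlib import contextmanager


@contextmanager
def isolated_omp_env(threads_per_worker: int = 1):
    keys = ("OMP_PROC_BIND", "OMP_PLACES", "OMP_NUM_THREADS")
    # 1. Save the parent's values (None means "was not set").
    saved = {key: os.environ.get(key) for key in keys}
    try:
        # 2. Clear thread-binding variables so workers do not fight over cores.
        os.environ.pop("OMP_PROC_BIND", None)
        os.environ.pop("OMP_PLACES", None)
        # 3. Limit OpenMP threads per worker to avoid oversubscription.
        os.environ["OMP_NUM_THREADS"] = str(threads_per_worker)
        yield
    finally:
        # 4. Restore the parent environment, removing keys that were unset.
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
```

Workers started inside the `with isolated_omp_env(...):` block inherit the modified environment at spawn time, and the `finally` clause restores the parent even if spawning raises.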
Rationale
Why CMC over standard MCMC?
Standard single-chain NUTS on all N data points takes hours per chain and offers no
parallelism across the data. CMC decomposes the problem: each shard runs NUTS
independently (trivially parallel), and the consensus step is a simple weighted
average of the shard posterior draws. With M shards processed by N_workers workers,
the total wall time is approximately T_shard × ⌈M / N_workers⌉, where T_shard is the
time to run NUTS on a single shard.
Why NUTS over Metropolis-Hastings?
The 14-parameter heterodyne model has strong posterior correlations (e.g., \(D_{0,\mathrm{ref}}\) and \(D_{0,\mathrm{sample}}\)). Metropolis-Hastings requires hand-tuned proposal distributions for each parameter pair, which is impractical for production use. NUTS adapts its step size and trajectory length automatically via dual averaging.
Why dense mass matrix?
The 14-parameter space has off-diagonal correlations that would force NUTS to take very small steps with a diagonal mass matrix. A dense mass matrix, adapted during warmup from the estimated posterior covariance (in effect a full Hessian approximation), lets NUTS take larger, correlated steps, reducing the number of leapfrog steps needed for convergence.
Why multiprocessing over threading?
JAX's Python-level code runs under the Global Interpreter Lock (GIL); threading would
serialize the Python-side dispatch of the NUTS leapfrog steps. Multiprocessing (spawn
mode) gives each worker a fresh Python interpreter with its own JAX device context and
XLA flags.
Consequences
Positive:
- Wall time scales as T_shard × ⌈M / N_workers⌉ for M shards, giving near-linear speedup with core count.
- Each shard posterior is independently diagnosed (R-hat, ESS, divergences per shard).
- Failed shards can be retried or skipped without re-running all shards.
- The consensus combination is exact when the shard posteriors are Gaussian, for any number of shards (Scott et al. 2016); for near-Gaussian shard posteriors it remains a close approximation of the full posterior.
Negative / Accepted trade-offs:
- CMC is an approximation: the consensus posterior is not exactly equal to the full posterior. The approximation is better as shard size increases.
- Worker spawning overhead (~2–5 s per worker) adds latency for small datasets.
- The multiprocessing backend requires spawn mode (not fork), which re-imports all modules in each worker. Heavy imports (JAX, NumPyro) add ~3 s per worker startup.
- ArviZ diagnostics must aggregate across shards; the summarize_diagnostics() function handles this but adds code complexity.
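To make the cross-shard aggregation concrete, one possible worst-case reduction is sketched below. It is not the actual `summarize_diagnostics()` implementation: the `ShardDiagnostics` container, the `aggregate_shard_diagnostics` name, and the thresholds (R-hat above 1.01, ESS below 400) are illustrative assumptions.

```python
from dataclasses import dataclass


@dataclass
class ShardDiagnostics:
    shard_id: int
    max_rhat: float        # worst R-hat over parameters in this shard
    min_ess: float         # smallest effective sample size in this shard
    num_divergences: int   # divergent transitions after warmup


def aggregate_shard_diagnostics(shards, rhat_limit=1.01, ess_floor=400.0):
    """Reduce per-shard diagnostics to a worst-case summary (thresholds are illustrative)."""
    if not shards:
        raise ValueError("no shard diagnostics to aggregate")
    failed = [
        s.shard_id for s in shards
        if s.max_rhat > rhat_limit or s.min_ess < ess_floor
    ]
    return {
        "worst_rhat": max(s.max_rhat for s in shards),
        "worst_ess": min(s.min_ess for s in shards),
        "total_divergences": sum(s.num_divergences for s in shards),
        "failed_shards": failed,  # candidates for retry or exclusion
    }


summary = aggregate_shard_diagnostics([
    ShardDiagnostics(0, max_rhat=1.003, min_ess=850.0, num_divergences=0),
    ShardDiagnostics(1, max_rhat=1.060, min_ess=120.0, num_divergences=7),
    ShardDiagnostics(2, max_rhat=1.008, min_ess=610.0, num_divergences=1),
])
```

In this example only shard 1 is flagged, which ties in with the retry-or-skip behavior listed under the positive consequences.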
Alternatives Considered
A. Standard NUTS on full data
Correct by construction. Rejected because: too slow for practical use (> 8 hours on typical hardware for the 14-parameter heterodyne model with a 1000-frame dataset).
B. Variational inference (ADVI)
ADVI in NumPyro produces approximate posteriors much faster than NUTS. Rejected because: mean-field ADVI produces overconfident posteriors for correlated parameter spaces; the full-rank ADVI covariance requires O(d²) parameters (196 for 14 dims), which is slow to optimize and numerically unstable.
C. Parallel tempering (Pigeons-style)
Non-reversible parallel tempering (NRPT) is superior to CMC for multimodal posteriors. Rejected for the default backend because: the heterodyne posterior is expected to be unimodal after NLSQ warm-start; the computational cost of the swap moves between temperature levels adds overhead. May be reconsidered if multimodal posteriors are observed in practice.
D. Ensemble sampler (emcee-style)
Affine-invariant ensemble samplers handle correlated posteriors without tuning. Rejected because: ensemble samplers are not trivially parallelizable across data shards, and the CMC consensus framework is not directly applicable to ensemble states.
See also
ADR-001: JAX CPU-Only Backend — CPU-only JAX decision
ADR-002: NLSQ / CMC Architectural Split — NLSQ / CMC architectural split
Computational Methods — CMC mathematical details