Source code for heterodyne.optimization.cmc.backends.pjit

"""Multi-device parallel MCMC backend using JAX sharding.

Distributes NUTS chains across multiple CPU devices using JAX's modern
sharding API (``jax.sharding``), available since JAX 0.4.1.  This
replaces the deprecated ``jax.experimental.pjit`` with the stable
``jax.jit`` + sharding annotation approach.  Heterodyne is CPU-only.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
import numpy as np
from numpyro.infer import MCMC, NUTS
from numpyro.infer import initialization as numpyro_init

from heterodyne.optimization.cmc.backends.base import (
    _NUTS_EXTRA_FIELDS,
    BackendCapabilities,
    CMCBackend,
)
from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import Callable

    from heterodyne.optimization.cmc.config import CMCConfig

# Module-level logger named after this module.
logger = get_logger(__name__)

# Map config string names to NumPyro initialization factories.
# ``PjitBackend.run`` looks up ``config.init_strategy`` in this table, falls
# back to ``init_to_median`` when the name is not listed, and calls the
# selected factory with no arguments to build the ``init_strategy`` handed
# to ``NUTS``.
_INIT_STRATEGY_MAP: dict[str, Callable[..., Any]] = {
    "init_to_median": numpyro_init.init_to_median,
    "init_to_sample": numpyro_init.init_to_sample,
    "init_to_value": numpyro_init.init_to_value,
}


def _slice_init_params(
    init_params: dict[str, jnp.ndarray] | None,
    start: int,
    count: int,
    n_chains_total: int,
) -> dict[str, jnp.ndarray] | None:
    """Return the init values for the chain block ``[start, start+count)``.

    A leaf is treated as chain-shaped when it has a non-empty ``shape``
    whose leading axis equals ``n_chains_total``.  Every other leaf
    (scalars, broadcast inits, leaves without a ``shape``) is forwarded
    unchanged so common cases keep working.

    Chain-shaped leaves are handled as follows:

    * ``count > 1``: sliced to ``leaf[start:start + count]``, keeping a
      leading chain axis of length ``count``.
    * ``count == 1``: indexed as ``leaf[start]``, which drops the chain
      axis.  The caller runs such a block with ``num_chains=1``, and
      NumPyro treats single-chain ``init_params`` as unbatched; a leftover
      length-1 axis would be taken as part of the parameter's own shape.

    Args:
        init_params: Mapping of parameter name to initial value, or ``None``.
        start: Index of the first chain in the block.
        count: Number of chains in the block.
        n_chains_total: Total number of chains across all blocks.

    Returns:
        ``init_params`` itself when it is ``None`` or when the block spans
        every chain; otherwise a new dict restricted to the block.
    """
    if init_params is None or count == n_chains_total:
        return init_params

    # Integer index for a lone chain (drops the axis), slice otherwise.
    selector = start if count == 1 else slice(start, start + count)

    sliced: dict[str, jnp.ndarray] = {}
    for name, leaf in init_params.items():
        shape = getattr(leaf, "shape", None)
        if shape and shape[0] == n_chains_total:
            sliced[name] = leaf[selector]
        else:
            sliced[name] = leaf
    return sliced


class PjitBackend(CMCBackend):
    """Multi-device MCMC backend that splits NUTS chains over JAX devices.

    The requested chains are divided as evenly as possible over
    ``jax.devices()``.  Each device that receives at least one chain gets
    its own NumPyro ``NUTS``/``MCMC`` run, executed inside a
    ``jax.default_device`` context, and the per-device samples are
    concatenated afterwards.

    When only a single device is available, all chains are assigned to
    that device.

    This backend relies on ``jax.jit``-era device placement rather than the
    deprecated ``jax.experimental.pjit``.
    """

    def run(
        self,
        model: Callable[..., Any],
        config: CMCConfig,
        rng_key: jnp.ndarray,
        init_params: dict[str, jnp.ndarray] | None = None,
    ) -> dict[str, Any]:
        """Run NUTS sampling with chains split across the available devices.

        Devices are visited one after another; each device with a non-zero
        chain share runs an independent ``MCMC`` instance and the resulting
        samples are combined with ``combine_shard_samples``.

        Args:
            model: NumPyro model function.
            config: CMC configuration with sampling hyperparameters.
            rng_key: JAX PRNG key; split into one key per device.
            init_params: Optional per-chain initial parameter values.

        Returns:
            Dictionary mapping parameter names to flat sample arrays.

        Raises:
            ValueError: If ``config.num_chains`` is smaller than 1.
            RuntimeError: If no JAX devices are available.  Errors raised
                by NumPyro during sampling propagate unchanged.
        """
        devices = jax.devices()
        n_devices = len(devices)
        n_chains = config.num_chains

        # Validate before logging so a bad configuration fails immediately.
        if n_chains < 1:
            raise ValueError(f"PjitBackend: num_chains must be >= 1, got {n_chains}")
        if n_devices < 1:
            raise RuntimeError("PjitBackend: no JAX devices available")

        logger.info(
            "PjitBackend: distributing %d chains across %d device(s) "
            "(%d warmup, %d samples each)",
            n_chains,
            n_devices,
            config.num_warmup,
            config.num_samples,
        )

        # Every device gets the floor share; the first ``remainder`` devices
        # take one extra chain.  With fewer chains than devices the floor
        # share is zero and the surplus devices are skipped, so exactly
        # ``n_chains`` chains are run.
        chains_per_device, remainder = divmod(n_chains, n_devices)

        init_fn = _INIT_STRATEGY_MAP.get(
            config.init_strategy, numpyro_init.init_to_median
        )

        # One key per device, indexed by device position.
        rng_keys = jax.random.split(rng_key, n_devices)

        shard_results: list[dict[str, Any]] = []
        chain_cursor = 0  # position along the chain axis of init_params

        for device_idx, device in enumerate(devices):
            device_chains = chains_per_device + (1 if device_idx < remainder else 0)
            if device_chains == 0:
                continue

            device_rng = rng_keys[device_idx]

            # Restrict chain-shaped init leaves to this device's block;
            # leaves without a chain axis are forwarded unchanged.
            device_init = _slice_init_params(
                init_params, chain_cursor, device_chains, n_chains
            )
            chain_cursor += device_chains

            logger.debug(
                "PjitBackend: device %d (%s) running %d chain(s) "
                "[chains %d..%d of %d total]",
                device_idx,
                device.platform,
                device_chains,
                chain_cursor - device_chains,
                chain_cursor - 1,
                n_chains,
            )

            # Make ``device`` the default placement for this shard's work.
            with jax.default_device(device):
                kernel = NUTS(
                    model,
                    target_accept_prob=config.target_accept_prob,
                    max_tree_depth=config.max_tree_depth,
                    dense_mass=config.dense_mass,
                    init_strategy=init_fn(),
                )
                # NOTE(review): NumPyro documents chain_method="parallel" as
                # spreading chains over XLA devices, so a multi-chain block
                # may not stay on ``device`` -- confirm whether "vectorized"
                # was intended here.
                mcmc = MCMC(
                    kernel,
                    num_warmup=config.num_warmup,
                    num_samples=config.num_samples,
                    num_chains=device_chains,
                    chain_method="parallel" if device_chains > 1 else "sequential",
                    progress_bar=(device_idx == 0),  # only for the first device
                )
                mcmc.run(
                    device_rng,
                    init_params=device_init,
                    extra_fields=_NUTS_EXTRA_FIELDS,
                )
                shard_results.append(dict(mcmc.get_samples()))

        combined = combine_shard_samples(shard_results)

        # Report the devices that actually ran chains, not all visible ones.
        logger.info(
            "PjitBackend: sampling complete — %d total samples across %d device(s)",
            config.num_samples * n_chains,
            len(shard_results),
        )
        return combined

    def get_capabilities(self) -> BackendCapabilities:
        """Return capabilities for multi-device parallel execution.

        Returns:
            BackendCapabilities with sharding and parallel-chain support
            enabled and ``max_parallel_shards`` set to the number of JAX
            devices currently visible.
        """
        return BackendCapabilities(
            supports_sharding=True,
            supports_parallel_chains=True,
            max_parallel_shards=len(jax.devices()),
        )

    def validate_resources(self) -> None:
        """Check which JAX devices are available for sharding.

        Logs a warning (but does not raise) when only one device is
        detected, since the backend still works in single-device mode, and
        logs the device list otherwise.

        Raises:
            RuntimeError: If no JAX devices are available at all.
        """
        devices = jax.devices()
        if not devices:
            raise RuntimeError("PjitBackend: no JAX devices available")

        if len(devices) == 1:
            logger.warning(
                "PjitBackend: only 1 device detected (%s); "
                "sharding will not provide parallelism. "
                "Consider using CPUBackend instead.",
                devices[0].platform,
            )
            return

        logger.info(
            "PjitBackend: %d devices available (%s)",
            len(devices),
            ", ".join(f"{d.platform}:{d.id}" for d in devices),
        )

    def estimate_memory(
        self,
        n_data: int,
        n_params: int,
        n_chains: int,
    ) -> float:
        """Estimate peak memory per device in GB.

        The per-chain figure is the sum of three float64 (8-byte) terms:
        ``n_data * n_params`` (Jacobian-like), ``n_params * n_params``
        (mass matrix) and ``n_data`` (residuals).  It is multiplied by the
        number of chains on the busiest device.

        Args:
            n_data: Number of data points per shard.
            n_params: Number of model parameters.
            n_chains: Total number of MCMC chains.

        Returns:
            Estimated peak memory in GB per device.
        """
        n_devices = max(len(jax.devices()), 1)

        # Ceiling division: ``run`` hands the first ``n_chains % n_devices``
        # devices one extra chain, so the busiest device holds the rounded-up
        # share.  Never estimate for fewer than one chain.
        chains_per_device = max(1, -(-n_chains // n_devices))

        bytes_per_chain = (
            n_data * n_params * 8  # Jacobian-like
            + n_params * n_params * 8  # Mass matrix
            + n_data * 8  # Residuals
        )
        total_bytes = bytes_per_chain * chains_per_device
        return total_bytes / (1024**3)

    def cleanup(self) -> None:
        """Release resources. No-op for JAX-managed devices."""
        logger.debug("PjitBackend: cleanup (no-op for JAX-managed devices)")

def combine_shard_samples(
    shard_results: list[dict[str, Any]],
) -> dict[str, Any]:
    """Merge posterior samples produced by several device shards.

    For every parameter of the first shard, the arrays of all shards that
    contain that parameter are converted to NumPy and joined along axis 0
    (the samples dimension).  Parameters that do not appear in the first
    shard are not part of the result.

    Args:
        shard_results: One sample dictionary per device shard, each mapping
            parameter names to numpy/JAX arrays.

    Returns:
        The single input dictionary itself when only one shard is given;
        otherwise a new dictionary of concatenated NumPy arrays.

    Raises:
        ValueError: If ``shard_results`` is empty.
    """
    if not shard_results:
        raise ValueError("combine_shard_samples requires at least 1 shard result")

    reference = shard_results[0]
    n_shards = len(shard_results)
    if n_shards == 1:
        # Nothing to merge: hand back the lone shard untouched.
        return reference

    merged: dict[str, Any] = {
        key: np.concatenate(
            [np.asarray(shard[key]) for shard in shard_results if key in shard],
            axis=0,
        )
        for key in reference
    }

    logger.debug(
        "combine_shard_samples: combined %d shards → %d parameters",
        n_shards,
        len(merged),
    )
    return merged