"""High-level NUTS sampler wrapper for heterodyne CMC analysis.
Provides a ``SamplingPlan`` dataclass for sampling hyperparameters and a
``NUTSSampler`` class that wraps NumPyro's MCMC with ergonomic factories,
automatic chain initialization with perturbation, and ArviZ diagnostics.
"""
from __future__ import annotations
import math
import secrets
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import arviz as az
import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS
from numpyro.infer import initialization as numpyro_init
from heterodyne.optimization.cmc.config import effective_warmup_floor
from heterodyne.utils.logging import get_logger
if TYPE_CHECKING:
from collections.abc import Callable
from heterodyne.optimization.cmc.config import CMCConfig
# Module-level logger shared by every helper and class in this file.
logger = get_logger(__name__)
#: Below this shard size, ``SamplingPlan.for_shard`` downgrades a ``"parallel"``
#: chain method to ``"sequential"`` so JIT and warmup costs are shared across
#: chains. Mirrors the homodyne small-shard policy. The comparison is strict
#: (``shard_size < threshold``), so a shard of exactly this size keeps its
#: configured chain method; ``"vectorized"`` plans are never downgraded.
_SMALL_SHARD_CHAIN_THRESHOLD: int = 100
# ---------------------------------------------------------------------------
# Adapter state diagnostics — homodyne CMC parity
# ---------------------------------------------------------------------------
#
# These helpers introspect the NumPyro ``last_state`` / ``adapt_state``
# objects so callers can log the adapted step size, inverse mass matrix,
# and accept-prob trajectory at the end of a run. They are robust to
# the multiple shapes NumPyro uses (scalar vs per-chain dense matrix vs
# dict-of-arrays) and never raise.
def _summarize_inverse_mass_matrix(inv_mass: Any) -> str:
    """Describe an adapted inverse mass matrix in one short string.

    Recognised layouts:

    * ``dict`` — key count plus a description of the first entry only;
    * 0-D value — ``scalar=...``;
    * 1-D array — treated as a diagonal: finite min/max and their ratio;
    * square 2-D array — dense: finite diagonal min/max and
      ``numpy.linalg.cond``;
    * ``(n_chains, dim, dim)`` array — per-chain dense, first two chains shown;
    * ``list`` / ``tuple`` — per-chain container, first two entries shown.

    Anything else falls back to a shape/type description. The summary lets
    callers spot mass-matrix pathologies (near-singular, exploding condition
    numbers) at a glance.

    Args:
        inv_mass: The inverse mass matrix in any of the layouts above.

    Returns:
        Human-readable single-line summary.
    """
    import numpy as _np

    def _describe(obj: Any) -> str:
        # Dict form: report the key count and recurse into the first entry.
        if isinstance(obj, dict):
            names = list(obj.keys())
            if names:
                head = names[0]
                return f"dict(keys={len(names)}) first[{head}]: {_describe(obj[head])}"
            return "dict(empty)"

        try:
            arr = _np.asarray(obj)
        except Exception:  # noqa: BLE001
            return f"type={type(obj).__name__}"

        rank = arr.ndim
        if rank == 0:
            try:
                return f"scalar={float(arr):.3g}"
            except Exception:  # noqa: BLE001
                return f"scalar(type={type(arr.item()).__name__})"

        if rank == 1:
            # Diagonal mass matrix: only finite entries enter the statistics.
            finite_vals = arr[_np.isfinite(arr)]
            if finite_vals.size == 0:
                return f"diag(dim={arr.size}) all-nonfinite"
            lo = float(_np.min(finite_vals))
            hi = float(_np.max(finite_vals))
            ratio = float(hi / lo) if lo > 0 else float("inf")
            return f"diag(dim={arr.size}) min={lo:.3g} max={hi:.3g} cond~{ratio:.3g}"

        # rank >= 2 from here on, so the last two axes always exist.
        square_tail = arr.shape[-2] == arr.shape[-1]

        if rank == 2 and square_tail:
            size = arr.shape[0]
            finite_diag = _np.diag(arr)
            finite_diag = finite_diag[_np.isfinite(finite_diag)]
            if finite_diag.size == 0:
                return f"dense(dim={size}) diag all-nonfinite"
            lo = float(_np.min(finite_diag))
            hi = float(_np.max(finite_diag))
            try:
                cond = float(_np.linalg.cond(arr))
            except Exception:  # noqa: BLE001
                cond = float("nan")
            return (
                f"dense(dim={size}) diag[min={lo:.3g}, max={hi:.3g}] "
                f"cond={cond:.3g}"
            )

        if rank == 3 and square_tail:
            n_chains, dim = arr.shape[0], arr.shape[1]
            shown = [_describe(block) for block in arr[:2]]
            extra = f" (+{n_chains - 2} more)" if n_chains > 2 else ""
            return f"per-chain dense(dim={dim})[{', '.join(shown)}]{extra}"

        return f"array(shape={arr.shape}, ndim={arr.ndim})"

    if not isinstance(inv_mass, (list, tuple)):
        return _describe(inv_mass)

    # Per-chain container: describe at most the first two entries.
    shown = [_describe(entry) for entry in inv_mass[:2]]
    extra = f" (+{len(inv_mass) - 2} more)" if len(inv_mass) > 2 else ""
    return f"per-chain[{', '.join(shown)}]{extra}"
def _extract_adapt_states(last_state: Any) -> list[Any]:
    """Collect the ``adapt_state`` object(s) reachable from ``last_state``.

    Args:
        last_state: Either a single state object exposing ``adapt_state``, a
            ``list``/``tuple`` of such objects, or ``None``.

    Returns:
        ``[last_state.adapt_state]`` for a single state, one entry per item
        that has an ``adapt_state`` attribute for a list/tuple, and an empty
        list for ``None`` or any other input.
    """
    if last_state is None:
        return []

    # A single state object takes precedence over the container form; this
    # matters for namedtuple-style states, which are tuples themselves.
    if hasattr(last_state, "adapt_state"):
        return [last_state.adapt_state]

    collected: list[Any] = []
    if isinstance(last_state, (list, tuple)):
        for entry in last_state:
            if hasattr(entry, "adapt_state"):
                collected.append(entry.adapt_state)
    return collected
def _extract_step_sizes(adapt_states: list[Any]) -> list[float]:
    """Pull the final adapted ``step_size`` value(s) out of each adapt_state.

    A ``step_size`` may be a scalar or an array holding one value per chain
    (a batched multi-chain state). Array values are flattened so every
    chain's step size is reported; previously ``float()`` failed on them and
    the failure was silently swallowed, yielding an empty result.

    Args:
        adapt_states: Adapt-state objects exposing ``step_size`` either as an
            attribute or as a dict key. ``None`` entries are skipped.

    Returns:
        Flat list of step sizes in encounter order. Entries whose step size
        is missing or not convertible to ``float`` contribute nothing.
    """
    import numpy as _np

    def _flatten_to_floats(raw: Any) -> list[float]:
        # Works for Python scalars, NumPy arrays and JAX arrays alike.
        return [float(v) for v in _np.ravel(_np.asarray(raw))]

    step_sizes: list[float] = []
    for adapt_state in adapt_states:
        if adapt_state is None:
            continue

        # Attribute access is tried first, dict lookup second; the first
        # candidate that converts cleanly wins.
        candidates: list[Any] = []
        if hasattr(adapt_state, "step_size"):
            candidates.append(adapt_state.step_size)
        if isinstance(adapt_state, dict) and "step_size" in adapt_state:
            candidates.append(adapt_state["step_size"])

        for raw in candidates:
            try:
                values = _flatten_to_floats(raw)
            except Exception:  # noqa: BLE001 — diagnostics helper must not raise
                continue
            step_sizes.extend(values)
            break
    return step_sizes
def _log_array_stats(run_logger: Any, *, name: str, arr: Any) -> None:
    """Log a one-line statistical summary of an MCMC extra-field array.

    Nothing is logged when ``arr`` cannot be converted to an array or is
    empty. When no element is finite a short "all non-finite" line is
    emitted instead of the statistics.

    Args:
        run_logger: Object exposing ``info(message)``.
        name: Label placed at the start of the log line.
        arr: Array-like of values to summarise.
    """
    import numpy as _np

    try:
        values = _np.asarray(arr)
    except Exception:  # noqa: BLE001
        return
    if values.size == 0:
        return

    mask = _np.isfinite(values)
    if not _np.any(mask):
        run_logger.info(f"{name} stats: all non-finite, shape={values.shape}")
        return

    # Statistics are computed over finite entries only; the finite fraction
    # is reported over the whole array.
    kept = values[mask]
    pieces = [
        f"min={float(_np.min(kept)):.3g}",
        f"median={float(_np.median(kept)):.3g}",
        f"max={float(_np.max(kept)):.3g}",
        f"mean={float(_np.mean(kept)):.3g}",
        f"std={float(_np.std(kept)):.3g}",
        f"finite={float(_np.mean(mask)):.1%}",
        f"shape={values.shape}",
    ]
    run_logger.info(f"{name} stats: " + ", ".join(pieces))
[docs]
@dataclass(frozen=True)
class SamplingPlan:
    """Immutable bundle of NUTS sampling hyperparameters.

    Instances are validated on construction and never mutated; the scaling
    helpers (:meth:`from_config`, :meth:`for_shard`) always return a new
    plan.

    Attributes:
        num_warmup: Warmup (adaptation) steps per chain. Must be >= 1.
        num_samples: Posterior draws per chain after warmup. Must be >= 1.
        num_chains: Number of independent MCMC chains. Must be >= 1.
        target_accept: Target acceptance probability for dual-averaging
            step-size adaptation. Must lie in ``[0.1, 0.99]``; values in
            ``[0.6, 0.95]`` are typical.
        max_tree_depth: Maximum binary tree depth for NUTS. Larger values
            allow longer trajectories at higher per-step cost. Must be >= 1.
        adapt_step_size: Whether dual-averaging step-size adaptation runs
            during warmup.
        dense_mass: Whether a dense (full) mass matrix is estimated during
            warmup instead of a diagonal approximation.
        chain_method: One of ``"sequential"``, ``"parallel"`` or
            ``"vectorized"``.
        seed: Explicit random seed for reproducibility. ``None`` means a
            cryptographically random seed is drawn on demand.
        fast_warmup: Escape hatch forwarded to ``effective_warmup_floor``.
    """

    num_warmup: int = 500
    num_samples: int = 1000
    num_chains: int = 4
    target_accept: float = 0.8
    max_tree_depth: int = 10
    adapt_step_size: bool = True
    dense_mass: bool = True
    chain_method: str = "sequential"
    seed: int | None = None
    #: Rule 12 escape hatch propagated from :class:`CMCConfig.fast_warmup`.
    #: When True, ``for_shard`` and ``AdaptiveSamplingPlan`` skip the dense-mass
    #: warmup floor. CI / pytest fast-mode only — not for production posteriors.
    fast_warmup: bool = False

    def __post_init__(self) -> None:
        """Reject out-of-range hyperparameters.

        Raises:
            ValueError: If any count is below 1, ``target_accept`` is outside
                ``[0.1, 0.99]``, or ``chain_method`` is not recognised.
        """
        for counter in ("num_warmup", "num_samples", "num_chains"):
            count = getattr(self, counter)
            if count < 1:
                raise ValueError(f"{counter} must be >= 1, got {count}")

        # Written as a negated chained comparison so that NaN is rejected too.
        if not (0.1 <= self.target_accept <= 0.99):
            raise ValueError(
                f"target_accept must be in [0.1, 0.99], got {self.target_accept}"
            )

        if self.max_tree_depth < 1:
            raise ValueError(f"max_tree_depth must be >= 1, got {self.max_tree_depth}")

        known_methods = ("sequential", "parallel", "vectorized")
        if self.chain_method not in known_methods:
            raise ValueError(
                f"chain_method must be 'sequential', 'parallel', or 'vectorized', "
                f"got {self.chain_method!r}"
            )

    @property
    def effective_seed(self) -> int:
        """Return ``seed``, or a fresh random seed when ``seed`` is ``None``.

        The random fallback is drawn anew on every access (the plan is
        frozen, so nothing is cached); callers that need a stable value
        should read this property once and reuse the result.
        """
        return secrets.randbelow(2**31) if self.seed is None else self.seed

    @classmethod
    def from_config(
        cls,
        config: CMCConfig,
        n_data: int | None = None,
        n_params: int | None = None,
    ) -> SamplingPlan:
        """Build a ``SamplingPlan`` from a :class:`CMCConfig`.

        When ``config.adaptive_sampling`` is truthy and ``n_data`` is a
        positive number, warmup and sample counts are multiplied by
        ``sqrt(min(1, n_data / 10_000))`` and clamped from below by
        ``config.min_warmup`` / ``config.min_samples``. The warmup count is
        then passed through ``effective_warmup_floor`` in every case.

        Args:
            config: CMC configuration carrying the NUTS hyperparameters and
                adaptive-sampling knobs.
            n_data: Number of data points in this shard (or the full dataset
                when sharding is disabled). ``None`` disables adaptive
                scaling regardless of ``config.adaptive_sampling``.
            n_params: Number of varying model parameters. Reserved for
                future dimension-aware scaling; currently unused.

        Returns:
            Fully validated :class:`SamplingPlan`.
        """
        # Reference shard size used for proportional scaling.
        reference_shard_size = 10_000
        fast_warmup = getattr(config, "fast_warmup", False)

        warmup, draws = config.num_warmup, config.num_samples
        adaptive = config.adaptive_sampling and n_data is not None and n_data > 0
        if adaptive:
            # Square-root scaling: smaller shards need less warmup, but the
            # relationship is sub-linear because MCMC mixing time does not
            # degrade as fast as the raw data ratio suggests.
            sqrt_scale = math.sqrt(min(1.0, n_data / reference_shard_size))
            warmup = max(config.min_warmup, int(config.num_warmup * sqrt_scale))
            draws = max(config.min_samples, int(config.num_samples * sqrt_scale))
            logger.debug(
                "SamplingPlan.from_config: n_data=%d, scale=%.3f, "
                "num_warmup=%d, num_samples=%d",
                n_data,
                sqrt_scale,
                warmup,
                draws,
            )

        # Rule 12: dense-mass NUTS requires the configured warmup floor.
        warmup = effective_warmup_floor(
            warmup,
            dense_mass=config.dense_mass,
            fast_warmup=fast_warmup,
        )

        return cls(
            num_warmup=warmup,
            num_samples=draws,
            num_chains=config.num_chains,
            target_accept=config.target_accept_prob,
            max_tree_depth=config.max_tree_depth,
            adapt_step_size=True,
            dense_mass=config.dense_mass,
            chain_method=config.chain_method,
            seed=config.seed,
            fast_warmup=fast_warmup,
        )

    def for_shard(self, shard_size: int, full_size: int) -> SamplingPlan:
        """Return a scaled-down plan appropriate for a single CMC shard.

        Warmup and sample counts are multiplied by
        ``sqrt(min(1, shard_size / full_size))`` and clamped from below by
        ``max(1, count // 10)`` to avoid degenerate one-step runs. The warmup
        count is additionally passed through ``effective_warmup_floor``.
        A ``"parallel"`` chain method is replaced by ``"sequential"`` for
        shards smaller than ``_SMALL_SHARD_CHAIN_THRESHOLD``.

        Args:
            shard_size: Number of data points in this shard.
            full_size: Total number of data points across all shards.

        Returns:
            New :class:`SamplingPlan` with adjusted warmup/sample counts and
            otherwise identical hyperparameters (including the seed).

        Raises:
            ValueError: If ``shard_size <= 0`` or ``full_size <= 0``.
        """
        if shard_size <= 0:
            raise ValueError(f"shard_size must be > 0, got {shard_size}")
        if full_size <= 0:
            raise ValueError(f"full_size must be > 0, got {full_size}")

        scale = math.sqrt(min(1.0, shard_size / full_size))

        def _scaled(count: int) -> int:
            # Never drop below a tenth of the original count (and never below 1).
            return max(max(1, count // 10), int(count * scale))

        # Rule 12: per-shard scaling must not slip below the dense-mass floor.
        # ``self.fast_warmup`` is honoured so CI fast-mode keeps its opt-out.
        shard_warmup = effective_warmup_floor(
            _scaled(self.num_warmup),
            dense_mass=self.dense_mass,
            fast_warmup=self.fast_warmup,
        )
        shard_samples = _scaled(self.num_samples)

        # Homodyne CMC parity: very small shards cannot amortise the
        # parallel-chain dispatch overhead, so their chains run sequentially.
        shard_chain_method = self.chain_method
        too_small_for_parallel = (
            self.chain_method == "parallel"
            and shard_size < _SMALL_SHARD_CHAIN_THRESHOLD
        )
        if too_small_for_parallel:
            logger.info(
                "SamplingPlan.for_shard: shard_size=%d < %d; switching chain_method "
                "from 'parallel' to 'sequential' for amortised JIT cost.",
                shard_size,
                _SMALL_SHARD_CHAIN_THRESHOLD,
            )
            shard_chain_method = "sequential"

        logger.debug(
            "SamplingPlan.for_shard: shard_size=%d, full_size=%d, scale=%.3f, "
            "num_warmup=%d->%d, num_samples=%d->%d, chain_method=%s",
            shard_size,
            full_size,
            scale,
            self.num_warmup,
            shard_warmup,
            self.num_samples,
            shard_samples,
            shard_chain_method,
        )

        # The dataclass is frozen, so the adjusted plan is a brand-new instance.
        return SamplingPlan(
            num_warmup=shard_warmup,
            num_samples=shard_samples,
            num_chains=self.num_chains,
            target_accept=self.target_accept,
            max_tree_depth=self.max_tree_depth,
            adapt_step_size=self.adapt_step_size,
            dense_mass=self.dense_mass,
            chain_method=shard_chain_method,
            seed=self.seed,
            fast_warmup=self.fast_warmup,
        )
[docs]
class NUTSSampler:
    """High-level NUTS sampler wrapping NumPyro's MCMC.

    Manages kernel construction, chain initialization with perturbation,
    sampling execution, and ArviZ diagnostic extraction.

    Use the :meth:`from_plan` factory for the standard construction path.
    """

    #: Per-step fields captured on every run so downstream diagnostics
    #: (BFMI, tree-depth analysis, accept-prob stats) have what they need.
    _EXTRA_FIELDS: tuple[str, ...] = (
        "energy",
        "diverging",
        "accept_prob",
        "num_steps",
        "potential_energy",
    )

    def __init__(
        self,
        mcmc: MCMC,
        plan: SamplingPlan,
    ) -> None:
        """Wrap an already-configured ``MCMC`` object.

        Args:
            mcmc: NumPyro ``MCMC`` instance to drive.
            plan: Sampling plan the ``mcmc`` object was built from.
        """
        self._mcmc = mcmc
        self._plan = plan
        self._has_run = False

    @property
    def plan(self) -> SamplingPlan:
        """The sampling plan used to configure this sampler."""
        return self._plan

    @classmethod
    def from_plan(
        cls,
        plan: SamplingPlan,
        model: Callable[..., Any],
        init_strategy: str = "init_to_median",
        chain_method: str | None = None,
    ) -> NUTSSampler:
        """Create a NUTSSampler from a SamplingPlan and NumPyro model.

        Args:
            plan: Sampling hyperparameters.
            model: NumPyro model function (callable with no required args).
            init_strategy: NumPyro initialization strategy name. One of
                ``"init_to_median"``, ``"init_to_sample"``,
                ``"init_to_value"``. Any other name falls back to
                ``"init_to_median"`` and a warning is logged.
            chain_method: NumPyro chain execution method. ``None`` (the
                default) uses ``plan.chain_method``.

        Returns:
            Configured NUTSSampler ready for :meth:`run`.
        """
        effective_chain_method = (
            chain_method if chain_method is not None else plan.chain_method
        )

        init_fn_map: dict[str, Callable[..., Any]] = {
            "init_to_median": numpyro_init.init_to_median,
            "init_to_sample": numpyro_init.init_to_sample,
            "init_to_value": numpyro_init.init_to_value,
        }
        init_factory = init_fn_map.get(init_strategy)
        if init_factory is None:
            # Keep the historical fallback, but make the substitution visible
            # instead of silently ignoring a misspelled strategy name.
            logger.warning(
                "NUTSSampler.from_plan: unknown init_strategy %r; "
                "falling back to 'init_to_median'",
                init_strategy,
            )
            init_factory = numpyro_init.init_to_median

        kernel = NUTS(
            model,
            target_accept_prob=plan.target_accept,
            max_tree_depth=plan.max_tree_depth,
            dense_mass=plan.dense_mass,
            adapt_step_size=plan.adapt_step_size,
            init_strategy=init_factory(),
        )
        mcmc = MCMC(
            kernel,
            num_warmup=plan.num_warmup,
            num_samples=plan.num_samples,
            num_chains=plan.num_chains,
            chain_method=effective_chain_method,
            progress_bar=True,
        )

        logger.info(
            "NUTSSampler created: %d chains, %d warmup, %d samples, "
            "target_accept=%s, chain_method=%s",
            plan.num_chains,
            plan.num_warmup,
            plan.num_samples,
            plan.target_accept,
            effective_chain_method,
        )
        return cls(mcmc, plan)

    def run(
        self,
        rng_key: jnp.ndarray | None = None,
        init_params: dict[str, jnp.ndarray] | None = None,
    ) -> dict[str, Any]:
        """Run MCMC sampling.

        If ``init_params`` are provided, they are passed through
        ``_perturb_init_params`` (seeded with ``seed + 1``) so that chains
        do not all start from the identical point.

        Args:
            rng_key: JAX PRNG key. If ``None``, one is generated from
                the plan's effective seed.
            init_params: Optional initial values for each chain. Keys
                are parameter names; values should be scalars or have
                shape ``(num_chains,)``.

        Returns:
            Dictionary of posterior samples (ungrouped).

        Note:
            Exceptions raised by NumPyro/JAX during sampling propagate
            unchanged; they are not wrapped.
        """
        # Read the seed once: with ``plan.seed is None`` every access of
        # ``effective_seed`` draws a new random value.
        seed = self._plan.effective_seed
        if rng_key is None:
            rng_key = jax.random.PRNGKey(seed)

        # Apply perturbation to break chain symmetry.
        perturbed_params = None
        if init_params is not None:
            perturbed_params = _perturb_init_params(
                init_params,
                num_chains=self._plan.num_chains,
                seed=seed + 1,
            )

        logger.info("NUTSSampler: starting sampling (seed=%d)", seed)
        self._mcmc.run(
            rng_key,
            init_params=perturbed_params,
            extra_fields=self._EXTRA_FIELDS,
        )
        # Block until JAX's asynchronous dispatch has finished so that wall
        # time measured by callers reflects the actual compute time.
        jax.block_until_ready(self._mcmc.last_state)
        self._has_run = True

        samples = self._mcmc.get_samples()
        logger.info("NUTSSampler: sampling complete")
        return dict(samples)

    def run_with_init_values(
        self,
        init_values: dict[str, float],
        rng_key: jnp.ndarray | None = None,
    ) -> dict[str, Any]:
        """Run MCMC seeded from NLSQ warm-start values.

        Probes the log density at the initial point before launching full
        sampling and raises early with a diagnostic message if it is not
        finite.

        Args:
            init_values: NLSQ MAP estimates keyed by parameter name. Values
                should be in the same space as the NumPyro model samples
                (physics space, or reparameterized space if the model uses
                reparameterization).
            rng_key: JAX PRNG key. Generated from the plan seed if ``None``.

        Returns:
            Dictionary of posterior samples (ungrouped).

        Raises:
            RuntimeError: If the initial log density is not finite.
        """
        # NOTE: when ``plan.seed`` is None, ``run`` draws its own effective
        # seed again, so the seed it logs differs from the one used here.
        seed = self._plan.effective_seed
        if rng_key is None:
            rng_key = jax.random.PRNGKey(seed)

        # Convert scalar values to JAX arrays for compatibility with NumPyro.
        init_params: dict[str, jnp.ndarray] = {
            name: jnp.asarray(val) for name, val in init_values.items()
        }

        # Preflight: check that the init point has finite log density.
        self._validate_init_log_density(init_params, rng_key)

        logger.info(
            "NUTSSampler.run_with_init_values: warm-starting from %d NLSQ parameters",
            len(init_values),
        )
        return self.run(rng_key=rng_key, init_params=init_params)

    @staticmethod
    def _format_init_value(value: Any) -> str:
        """Format one init value for an error message without ever raising."""
        try:
            return f"{float(value):.4g}"
        except (TypeError, ValueError):
            # Non-scalar (or otherwise unconvertible) value: report its shape.
            return f"array(shape={getattr(value, 'shape', '?')})"

    def _validate_init_log_density(
        self,
        init_params: dict[str, jnp.ndarray],
        rng_key: jnp.ndarray,
    ) -> None:
        """Check that the initial parameter point yields a finite log density.

        The model's unnormalized log-joint is evaluated with
        ``numpyro.infer.util.log_density`` — no sampling is performed. If the
        probe itself fails for a non-``RuntimeError`` reason the check is
        skipped (logged at DEBUG) so that sampling is not blocked.

        Args:
            init_params: Initial parameter values (same format as
                ``init_params`` accepted by :meth:`run`).
            rng_key: JAX PRNG key. Accepted for interface stability;
                currently not used by the probe.

        Raises:
            RuntimeError: If the log density at ``init_params`` is not finite.
        """
        try:
            # ``log_density`` works before any sampling has happened, unlike
            # the kernel's potential function, which is only set up by a run.
            from numpyro.infer.util import log_density as _numpyro_log_density

            kernel = self._mcmc.sampler  # type: ignore[attr-defined]
            model_fn = kernel.model  # type: ignore[attr-defined]
            if model_fn is None:
                logger.debug(
                    "_validate_init_log_density: model not available on kernel, skipping"
                )
                return

            log_density_val, _ = _numpyro_log_density(model_fn, (), {}, init_params)
            log_density = float(log_density_val)
        except RuntimeError:
            raise
        except Exception as exc:  # noqa: BLE001
            # Non-critical: the probe can fail for various internal NumPyro
            # reasons. Log and continue rather than blocking sampling.
            logger.debug("_validate_init_log_density: probe failed (%s), skipping", exc)
            return

        if math.isfinite(log_density):
            logger.debug(
                "_validate_init_log_density: log_density=%.4g (finite)", log_density
            )
            return

        # Built outside the ``try`` with a non-raising formatter so a
        # non-scalar init value cannot mask the non-finite detection.
        param_summary = ", ".join(
            f"{k}={self._format_init_value(v)}" for k, v in init_params.items()
        )
        raise RuntimeError(
            f"Initial log density is not finite ({log_density}) at "
            f"NLSQ warm-start point: {param_summary}. "
            "Check that init values lie within the model's prior support."
        )

    @staticmethod
    def _tree_depths_from_num_steps(num_steps: Any) -> Any:
        """Convert NUTS ``num_steps`` values into tree depths.

        Uses ``depth = floor(log2(num_steps)) + 1``, the conversion given in
        NumPyro's ``HMCState`` documentation. A full tree of depth ``d`` has
        ``2**d - 1`` steps, so plain ``log2`` lands just below ``d`` and
        never reaches ``max_tree_depth``.

        Args:
            num_steps: Array-like of per-sample leapfrog step counts.

        Returns:
            Float NumPy array of tree depths with the same shape. Step
            counts that are not strictly positive are treated as 1.
        """
        import numpy as np

        steps = np.asarray(num_steps, dtype=float)
        # Guard against zero steps (shouldn't occur, but be safe).
        steps = np.where(steps > 0, steps, 1.0)
        return np.floor(np.log2(steps)) + 1.0

    def get_divergence_stats(self) -> dict[str, float]:
        """Extract divergence rate and tree-depth statistics from the last run.

        Requires that :meth:`run` or :meth:`run_with_init_values` has been
        called.

        Returns:
            Dictionary with the following keys:

            ``"divergence_rate"``
                Mean of the ``"diverging"`` extra field; ``0.0`` when the
                field is absent.
            ``"mean_tree_depth"``
                Mean NUTS tree depth over all recorded samples. Values near
                ``max_tree_depth`` indicate truncated trajectories. ``nan``
                when neither ``"tree_depth"`` nor ``"num_steps"`` is
                available.
            ``"max_tree_depth_fraction"``
                Fraction of samples whose depth is at least
                ``plan.max_tree_depth``; ``nan`` under the same condition.

        Raises:
            RuntimeError: If called before :meth:`run`.
        """
        if not self._has_run:
            raise RuntimeError("Cannot extract divergence stats before calling run()")

        import numpy as np

        extra = self._mcmc.get_extra_fields()

        divergence_rate = 0.0
        if "diverging" in extra:
            divergence_rate = float(np.mean(np.asarray(extra["diverging"], dtype=bool)))

        # Prefer an explicit "tree_depth" field when present; otherwise
        # derive the depth from the leapfrog step count.
        depths = None
        if "tree_depth" in extra:
            depths = np.asarray(extra["tree_depth"], dtype=float)
        elif "num_steps" in extra:
            depths = self._tree_depths_from_num_steps(extra["num_steps"])

        mean_tree_depth = float("nan")
        max_depth_fraction = float("nan")
        if depths is not None:
            mean_tree_depth = float(np.mean(depths))
            max_depth_fraction = float(np.mean(depths >= self._plan.max_tree_depth))

        stats: dict[str, float] = {
            "divergence_rate": divergence_rate,
            "mean_tree_depth": mean_tree_depth,
            "max_tree_depth_fraction": max_depth_fraction,
        }
        logger.debug(
            "get_divergence_stats: divergence_rate=%.4f, mean_tree_depth=%.2f, "
            "max_depth_fraction=%.4f",
            divergence_rate,
            mean_tree_depth,
            max_depth_fraction,
        )
        return stats

    def get_diagnostics(self) -> az.InferenceData:
        """Extract ArviZ InferenceData for convergence diagnostics.

        Returns:
            The result of ``az.from_numpyro`` applied to the wrapped MCMC
            object.

        Raises:
            RuntimeError: If called before :meth:`run`.
        """
        if not self._has_run:
            raise RuntimeError("Cannot extract diagnostics before calling run()")
        return az.from_numpyro(self._mcmc)

    def log_adapter_diagnostics(self, run_logger: Any | None = None) -> None:
        """Log NUTS adapter state at INFO level — homodyne CMC parity helper.

        Reports the adapted ``step_size`` value(s), a compact summary of
        the adapted inverse mass matrix, and per-step ``accept_prob`` /
        ``num_steps`` / ``potential_energy`` statistics when the
        corresponding extra fields are present.

        Args:
            run_logger: Logger to emit through. ``None`` uses this
                module's logger.

        Raises:
            RuntimeError: If called before :meth:`run`.
        """
        if not self._has_run:
            raise RuntimeError(
                "Cannot extract adapter diagnostics before calling run()"
            )

        log_inst = run_logger if run_logger is not None else logger
        last_state = getattr(self._mcmc, "last_state", None)
        adapt_states = _extract_adapt_states(last_state)

        step_sizes = _extract_step_sizes(adapt_states)
        if step_sizes:
            log_inst.info(
                "Adapted step sizes (per chain): %s",
                ", ".join(f"{s:.3g}" for s in step_sizes),
            )

        if adapt_states:
            # Only the first adapt_state is summarised.
            first = adapt_states[0]
            inv_mass = getattr(first, "inverse_mass_matrix", None)
            if inv_mass is None and isinstance(first, dict):
                inv_mass = first.get("inverse_mass_matrix")
            if inv_mass is not None:
                log_inst.info(
                    "Inverse mass matrix: %s",
                    _summarize_inverse_mass_matrix(inv_mass),
                )

        extra = self._mcmc.get_extra_fields()
        for field in ("accept_prob", "num_steps", "potential_energy"):
            if field in extra:
                _log_array_stats(log_inst, name=field, arr=extra[field])

    @property
    def mcmc(self) -> MCMC:
        """Access the underlying NumPyro MCMC object."""
        return self._mcmc
# ---------------------------------------------------------------------------
# Adaptive sampling plan
# ---------------------------------------------------------------------------
[docs]
@dataclass
class AdaptiveSamplingPlan:
    """Sampling plan that adjusts warmup/sample counts based on shard size.

    Wraps a base :class:`SamplingPlan` and scales it down when the shard is
    smaller than a reference size, while respecting floors derived from the
    parameter count.

    The scaling rule is::

        scale = min(1.0, sqrt(shard_size / reference_shard_size))
        num_warmup  = max(min_warmup_floor,  int(base.num_warmup  * scale))
        num_samples = max(min_samples_floor, int(base.num_samples * scale))

    where ``min_warmup_floor = max(50, 5 * n_params)`` and
    ``min_samples_floor = max(100, 10 * n_params)``. The warmup count is
    then passed through ``effective_warmup_floor``.

    Attributes:
        base_plan: Base :class:`SamplingPlan` for a full-size shard.
        shard_size: Number of data points in this shard. Must be > 0.
        n_params: Number of varying model parameters, used for the
            count floors. Must be > 0.
    """

    base_plan: SamplingPlan
    shard_size: int
    n_params: int
    #: Reference shard size at which no scaling is applied. Must be > 0.
    _reference_shard_size: int = 10_000

    def __post_init__(self) -> None:
        """Validate the sizes used by :meth:`get_plan`.

        Raises:
            ValueError: If ``shard_size``, ``n_params`` or
                ``_reference_shard_size`` is not strictly positive.
        """
        if self.shard_size <= 0:
            raise ValueError(f"shard_size must be > 0, got {self.shard_size}")
        if self.n_params <= 0:
            raise ValueError(f"n_params must be > 0, got {self.n_params}")
        # The reference size is the divisor under a square root in get_plan;
        # reject bad values here rather than failing later with a
        # ZeroDivisionError or math domain error.
        if self._reference_shard_size <= 0:
            raise ValueError(
                f"_reference_shard_size must be > 0, got {self._reference_shard_size}"
            )

    def get_plan(self) -> SamplingPlan:
        """Return a :class:`SamplingPlan` adjusted for this shard.

        Scaling is sub-linear (square-root) and capped at 1, so shards at or
        above the reference size keep the base counts and smaller shards
        still receive enough samples to characterise the posterior.

        The pre-floor warmup minimum is ``max(50, 5 * n_params)`` and the
        sample minimum is ``max(100, 10 * n_params)``.

        Returns:
            Scaled :class:`SamplingPlan`; every field other than
            ``num_warmup`` and ``num_samples`` is copied from ``base_plan``.
        """
        base = self.base_plan
        scale = min(
            1.0,
            math.sqrt(self.shard_size / self._reference_shard_size),
        )

        # Parameter-aware floors.
        min_warmup = max(50, 5 * self.n_params)
        min_samples = max(100, 10 * self.n_params)

        new_warmup = max(min_warmup, int(base.num_warmup * scale))
        new_samples = max(min_samples, int(base.num_samples * scale))

        # Rule 12: the parameter-aware floor (70 for 14 params) is too low
        # for dense-mass NUTS, so apply the configured warmup floor on top.
        new_warmup = effective_warmup_floor(
            new_warmup,
            dense_mass=base.dense_mass,
            fast_warmup=base.fast_warmup,
        )

        logger.debug(
            "AdaptiveSamplingPlan.get_plan: shard_size=%d, n_params=%d, "
            "scale=%.3f, num_warmup=%d, num_samples=%d",
            self.shard_size,
            self.n_params,
            scale,
            new_warmup,
            new_samples,
        )

        return SamplingPlan(
            num_warmup=new_warmup,
            num_samples=new_samples,
            num_chains=base.num_chains,
            target_accept=base.target_accept,
            max_tree_depth=base.max_tree_depth,
            adapt_step_size=base.adapt_step_size,
            dense_mass=base.dense_mass,
            chain_method=base.chain_method,
            seed=base.seed,
            fast_warmup=base.fast_warmup,
        )
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Divergence rate thresholds
# ---------------------------------------------------------------------------
#: Target divergence rate — below this level the run is considered healthy.
#: NOTE(review): not referenced by any code visible in this part of the file;
#: confirm where it is consumed.
DIVERGENCE_RATE_TARGET: float = 0.01
#: Elevated divergence rate — triggers a retry in :func:`run_nuts_with_retry`.
#: Also the strict upper bound used by ``SamplingStats.is_healthy``.
DIVERGENCE_RATE_HIGH: float = 0.05
#: Critical divergence rate — posterior geometry is likely incompatible with HMC.
#: NOTE(review): not referenced by any code visible in this part of the file;
#: confirm where it is consumed.
DIVERGENCE_RATE_CRITICAL: float = 0.10
# ---------------------------------------------------------------------------
# SamplingStats dataclass
# ---------------------------------------------------------------------------
[docs]
@dataclass(frozen=True)
class SamplingStats:
    """Immutable record of headline numbers from one finished NUTS run.

    Attributes:
        num_samples: Posterior draws per chain (post-warmup).
        num_warmup: Warmup steps per chain.
        num_divergences: Total divergent transitions across all chains.
        divergence_rate: Fraction of post-warmup transitions that diverged.
        mean_accept_prob: Mean Metropolis acceptance probability.
        max_tree_depth_fraction: Fraction of samples that hit the maximum
            NUTS tree depth.
        wall_time_seconds: Elapsed wall-clock time for the run.
    """

    num_samples: int
    num_warmup: int
    num_divergences: int
    divergence_rate: float
    mean_accept_prob: float
    max_tree_depth_fraction: float
    wall_time_seconds: float

    @property
    def is_healthy(self) -> bool:
        """Whether the run passes both basic health checks.

        Both conditions must hold:

        * ``divergence_rate`` is strictly below :data:`DIVERGENCE_RATE_HIGH`
        * ``mean_accept_prob`` is strictly above ``0.6``
        """
        few_divergences = self.divergence_rate < DIVERGENCE_RATE_HIGH
        accepts_enough = self.mean_accept_prob > 0.6
        return few_divergences and accepts_enough
# ---------------------------------------------------------------------------
# Retry wrapper
# ---------------------------------------------------------------------------
[docs]
def run_nuts_with_retry(
sampler: NUTSSampler,
model_fn: Any,
model_kwargs: dict[str, Any],
max_retries: int = 3,
target_accept_increment: float = 0.05,
*,
step_size_factor: float | None = None,
) -> tuple[dict[str, Any], SamplingStats]:
"""Run NUTS sampling with automatic step-size reduction on high divergence.
Executes :meth:`~NUTSSampler.run` and checks the divergence rate
after each attempt. When the rate exceeds
:data:`DIVERGENCE_RATE_HIGH`, a new :class:`NUTSSampler` is built
with ``target_accept`` RAISED by ``target_accept_increment`` (which
drives dual averaging toward a SMALLER step size — the
mathematically correct response to high divergence rate) and the
run is retried. After ``max_retries`` attempts the result with the
lowest divergence rate is returned regardless of health. Mirrors
homodyne ``run_nuts_with_retry`` (sampler.py:1311-1314).
The ``model_fn`` is re-used across retries so it must be stateless
(i.e. a pure NumPyro model function with no side effects).
Args:
sampler: Configured :class:`NUTSSampler` for the first attempt.
model_fn: NumPyro model callable. Not called directly here but
passed to :meth:`NUTSSampler.from_plan` for retry instances.
model_kwargs: Keyword arguments forwarded to the model via
:meth:`~NUTSSampler.run`. Currently unused by
:meth:`~NUTSSampler.run` (which takes ``rng_key`` and
``init_params``); included for forward compatibility.
max_retries: Maximum number of additional attempts after the
first run. Total runs = ``max_retries + 1``.
target_accept_increment: Additive increase applied to
``target_accept`` each retry (e.g. 0.80 → 0.85 → 0.90 →
0.95). Raising the target acceptance LOWERS the
dual-averaging step size, which is the mathematically
correct response to high divergence rates. Must be in
``(0, 0.5)``. The target is clamped at ``0.99`` to avoid
pathological tiny step sizes.
step_size_factor: DEPRECATED keyword-only alias. Earlier
versions multiplied ``target_accept`` by this factor
(with factor ``< 1``) on retry, which drove dual averaging
toward a LARGER step size — the OPPOSITE of what high
divergence rate calls for. When supplied,
``target_accept_increment`` is derived as
``(1 - step_size_factor) * 0.1`` so legacy callers see a
corrected mathematical direction.
Returns:
Tuple of ``(samples_dict, SamplingStats)`` for the best attempt
(lowest divergence rate).
"""
import time
import numpy as np
if step_size_factor is not None:
# Legacy alias: map a multiplicative factor < 1 to a positive
# increment so the direction is correct. E.g. legacy
# step_size_factor=0.5 ⇒ increment 0.05.
target_accept_increment = max(0.01, (1.0 - step_size_factor) * 0.1)
logger.warning(
"run_nuts_with_retry: 'step_size_factor' is deprecated and its "
"original direction was inverted; mapped to "
"target_accept_increment=%.4f",
target_accept_increment,
)
if not (0.0 < target_accept_increment < 0.5):
raise ValueError(
f"target_accept_increment must be in (0, 0.5), got "
f"{target_accept_increment}"
)
best_samples: dict[str, Any] | None = None
best_stats: SamplingStats | None = None
best_divergence_rate = float("inf")
current_sampler = sampler
current_target_accept = sampler.plan.target_accept
for attempt in range(max_retries + 1):
t_start = time.monotonic()
samples = current_sampler.run()
wall_time = time.monotonic() - t_start
div_stats = current_sampler.get_divergence_stats()
divergence_rate = div_stats["divergence_rate"]
max_tree_depth_fraction = div_stats.get("max_tree_depth_fraction", float("nan"))
# Extract mean acceptance probability from extra fields
extra = current_sampler.mcmc.get_extra_fields()
mean_accept_prob = 0.0
if "mean_accept_prob" in extra:
mean_accept_prob = float(np.mean(np.asarray(extra["mean_accept_prob"])))
elif "accept_prob" in extra:
mean_accept_prob = float(np.mean(np.asarray(extra["accept_prob"])))
plan = current_sampler.plan
n_divergent = int(round(divergence_rate * plan.num_samples * plan.num_chains))
stats = SamplingStats(
num_samples=plan.num_samples,
num_warmup=plan.num_warmup,
num_divergences=n_divergent,
divergence_rate=divergence_rate,
mean_accept_prob=mean_accept_prob,
max_tree_depth_fraction=max_tree_depth_fraction,
wall_time_seconds=wall_time,
)
logger.info(
"run_nuts_with_retry attempt %d/%d: "
"divergence_rate=%.4f, mean_accept_prob=%.4f, wall_time=%.1fs",
attempt + 1,
max_retries + 1,
divergence_rate,
mean_accept_prob,
wall_time,
)
if divergence_rate < best_divergence_rate:
best_divergence_rate = divergence_rate
best_samples = samples
best_stats = stats
# Stop early if divergence rate is acceptable
if divergence_rate <= DIVERGENCE_RATE_HIGH:
break
if attempt < max_retries:
# Raising target_accept drives dual averaging toward a SMALLER
# step size, which is the mathematically correct response to
# high divergence rate. The previous implementation reduced
# target_accept (silently making divergence WORSE).
current_target_accept = min(
0.99, current_target_accept + target_accept_increment
)
logger.warning(
"run_nuts_with_retry: divergence_rate=%.4f > %.2f; "
"retrying with target_accept=%.4f (attempt %d/%d)",
divergence_rate,
DIVERGENCE_RATE_HIGH,
current_target_accept,
attempt + 2,
max_retries + 1,
)
# Build a new sampler with raised target acceptance
new_plan = SamplingPlan(
num_warmup=plan.num_warmup,
num_samples=plan.num_samples,
num_chains=plan.num_chains,
target_accept=current_target_accept,
max_tree_depth=plan.max_tree_depth,
adapt_step_size=plan.adapt_step_size,
dense_mass=plan.dense_mass,
chain_method=plan.chain_method,
seed=plan.seed,
)
current_sampler = NUTSSampler.from_plan(
new_plan,
model_fn,
)
else:
# Exhausted retries
logger.warning(
"run_nuts_with_retry: exhausted %d retries; "
"returning best result with divergence_rate=%.4f",
max_retries,
best_divergence_rate,
)
# These are always set on the first iteration, so cannot be None here
assert best_samples is not None # noqa: S101
assert best_stats is not None # noqa: S101
return best_samples, best_stats
def _compute_mcmc_safe_d0_component(
d0: float | None,
alpha: float | None,
d_offset: float | None,
*,
q: float,
dt: float,
time_grid: Any | None,
target_g1: float = 0.5,
g1_threshold: float = 0.1,
) -> tuple[float, float] | None:
"""Detect vanishing-gradient pathology for a single transport component.
The heterodyne signal is a mixture of two single-exponential
contributions :math:`g_1 = f \\exp(-q^2 J_s) + (1-f) \\exp(-q^2 J_r)`.
Each component can independently kill the gradient if its diffusion
integral grows large; this helper inspects one component's
``(D0, alpha, D_offset)`` triple and returns a rescaled
``(D0', D_offset')`` pair when ``g1`` falls below ``g1_threshold``.
Mirrors homodyne ``_compute_mcmc_safe_d0`` (sampler.py:143-278) but
is adapted for heterodyne's two-channel architecture by being called
twice (once per ``ref``/``sample`` channel).
Args:
d0: Diffusion prefactor (e.g. ``D0_ref`` or ``D0_sample``).
alpha: Power-law exponent on the diffusion term.
d_offset: Additive offset to the diffusion rate.
q: Wavevector magnitude (Å⁻¹).
dt: Lag-time step (s).
time_grid: Time grid for integration. ``None`` → return ``None``.
target_g1: ``g1`` value the scaling targets when adjustment fires.
g1_threshold: ``g1`` floor below which an adjustment fires.
Returns:
``(new_d0, new_d_offset)`` when an MCMC-safe scaling is required;
``None`` when the original values are safe or cannot be evaluated.
"""
import numpy as _np
if d0 is None or alpha is None or d_offset is None:
return None
if time_grid is None or len(time_grid) < 2:
return None
if not (_np.isfinite(d0) and _np.isfinite(alpha) and _np.isfinite(d_offset)):
return None
try:
epsilon = 1e-10
time_safe = _np.asarray(time_grid) + epsilon
# Homodyne parity: use a gradient-safe floor for the integrand.
# jnp.maximum would zero the gradient under the floor; np.maximum
# is safe here because we evaluate at concrete values outside the
# JAX tracer.
d_grid = d0 * (time_safe**alpha) + d_offset
d_grid = _np.maximum(d_grid, 1e-10)
if len(d_grid) > 1:
trap_avg = 0.5 * (d_grid[:-1] + d_grid[1:])
cumsum = _np.concatenate([[0.0], _np.cumsum(trap_avg)])
else:
cumsum = _np.cumsum(d_grid)
n = len(cumsum)
integral_estimate = abs(cumsum[3 * n // 4] - cumsum[n // 4])
prefactor = q**2 * dt
log_g1 = -prefactor * integral_estimate
log_g1_clipped = max(log_g1, -700.0)
g1_estimate = _np.exp(log_g1_clipped)
if g1_estimate >= g1_threshold:
return None
target_log_g1 = _np.log(target_g1)
target_integral = -target_log_g1 / prefactor
scale_factor = (
target_integral / integral_estimate if integral_estimate > 0 else 0.01
)
new_d0 = max(float(d0 * scale_factor), 1.0)
new_d_offset = max(float(d_offset * scale_factor), -1e6)
return new_d0, new_d_offset
except Exception: # noqa: BLE001 — diagnostics helper must not crash MCMC
return None
[docs]
def compute_mcmc_safe_initial_values(
initial_values: dict[str, float] | None,
*,
q: float,
dt: float,
time_grid: Any | None,
target_g1: float = 0.5,
g1_threshold: float = 0.1,
) -> dict[str, float] | None:
"""Detect/repair vanishing-g1 initial parameters for heterodyne NUTS.
Inspects both the reference and sample transport components and, when
either drives ``g1`` below ``g1_threshold`` at a typical lag, rescales
its ``(D0, D_offset)`` pair so that ``g1 ≈ target_g1``. Returns
``None`` when no adjustment is needed so callers can short-circuit.
Args:
initial_values: NLSQ warm-start dictionary. Expected keys include
``D0_ref``, ``alpha_ref``, ``D_offset_ref``, ``D0_sample``,
``alpha_sample``, ``D_offset_sample``.
q: Wavevector magnitude (Å⁻¹).
dt: Lag-time step (s).
time_grid: Time grid for integration.
target_g1: Target ``g1`` value when rescaling.
g1_threshold: Threshold below which rescaling fires.
Returns:
A new dictionary with adjusted ``D0_*`` / ``D_offset_*`` values
when adjustment is required, else ``None``.
"""
if initial_values is None:
return None
adjusted: dict[str, float] = dict(initial_values)
any_change = False
for prefix in ("ref", "sample"):
d0_key = f"D0_{prefix}"
a_key = f"alpha_{prefix}"
off_key = f"D_offset_{prefix}"
result = _compute_mcmc_safe_d0_component(
adjusted.get(d0_key),
adjusted.get(a_key),
adjusted.get(off_key),
q=q,
dt=dt,
time_grid=time_grid,
target_g1=target_g1,
g1_threshold=g1_threshold,
)
if result is not None:
new_d0, new_off = result
logger.warning(
"compute_mcmc_safe_initial_values: %s pathology — "
"scaling %s %.4g -> %.4g, %s %.4g -> %.4g for MCMC stability",
prefix,
d0_key,
adjusted[d0_key],
new_d0,
off_key,
adjusted[off_key],
new_off,
)
adjusted[d0_key] = new_d0
adjusted[off_key] = new_off
any_change = True
return adjusted if any_change else None
def _perturb_init_params(
    init_params: dict[str, jnp.ndarray],
    num_chains: int,
    seed: int,
    perturbation_scale: float = 0.01,
) -> dict[str, jnp.ndarray]:
    """Jitter initial parameters so every chain gets its own start point.

    Each value is broadcast to ``(num_chains,)`` and offset by Gaussian
    noise whose standard deviation is ``perturbation_scale`` times the
    value's magnitude (plus ``1e-10`` so zero-valued parameters still
    receive a perturbation). One PRNG subkey is derived from ``seed`` for
    each parameter, in the dictionary's iteration order.

    Args:
        init_params: Base initial values. Scalars are broadcast to
            ``(num_chains,)``.
        num_chains: Number of MCMC chains.
        seed: Random seed for perturbation generation.
        perturbation_scale: Standard deviation of additive Gaussian
            perturbation relative to parameter magnitude.

    Returns:
        New dict with perturbed values of shape ``(num_chains,)``.
    """
    chain_shape = (num_chains,)
    names = list(init_params)
    per_param_keys = jax.random.split(jax.random.PRNGKey(seed), num=len(names))

    def _jitter(value: Any, key: Any) -> jnp.ndarray:
        # Replicate the base value once per chain.
        tiled = jnp.broadcast_to(jnp.asarray(value), chain_shape)
        # Relative noise scale; the additive floor covers zero-valued params.
        scale = jnp.abs(tiled) + 1e-10
        draw = jax.random.normal(key, shape=chain_shape)
        return tiled + perturbation_scale * scale * draw

    return {
        name: _jitter(init_params[name], per_param_keys[idx])
        for idx, name in enumerate(names)
    }