"""Core CMC fitting functions for heterodyne Bayesian analysis.
Includes the original single-run ``fit_cmc_jax`` and the new sharded
Consensus Monte Carlo entry point ``fit_cmc_sharded``, plus all supporting
helpers for shard creation, prior tempering, and posterior combination.
"""
from __future__ import annotations
import math
import os
import secrets
import time
import warnings
from collections.abc import Generator, Iterable
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, cast
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
from numpyro.infer import MCMC, NUTS
from heterodyne.config.parameter_registry import DEFAULT_REGISTRY
from heterodyne.optimization.cmc.config import CMCConfig
from heterodyne.optimization.cmc.data_prep import (
PooledCMCData,
prepare_mcmc_data,
shard_pooled_angle_balanced,
shard_pooled_random,
)
from heterodyne.optimization.cmc.diagnostics import (
analyze_divergences,
log_analysis_summary,
validate_convergence,
)
from heterodyne.optimization.cmc.model import (
estimate_sigma,
get_heterodyne_model,
get_heterodyne_model_reparam,
get_heterodyne_pooled_model_for_mode,
)
from heterodyne.optimization.cmc.priors import (
build_default_priors,
build_nlsq_informed_priors,
temper_priors,
)
from heterodyne.optimization.cmc.reparameterization import (
ReparamConfig,
compute_t_ref,
transform_nlsq_to_reparam_space,
transform_to_physics_space,
)
from heterodyne.optimization.cmc.results import CMCResult
from heterodyne.optimization.nlsq.results import NLSQResult
from heterodyne.utils.logging import get_logger
if TYPE_CHECKING:
from heterodyne.config.parameter_space import ParameterSpace
from heterodyne.core.heterodyne_model import HeterodyneModel
logger = get_logger(__name__)
@contextmanager
def _silence_arviz_diagnostic_warnings() -> Generator[None]:
    """Temporarily ignore two known-noisy warnings raised during diagnostics.

    While the context is active the following are filtered with action
    ``"ignore"``; the previous warning filters are restored on exit:

    * ``UserWarning`` whose message matches
      ``".*filter_vars on DataTree named None.*"``.
    * ``RuntimeWarning`` whose message starts with
      ``"invalid value encountered in scalar divide"``.

    All other warnings pass through unchanged.
    """
    # (category, message-regex) pairs, installed in this order.
    ignored = (
        (UserWarning, ".*filter_vars on DataTree named None.*"),
        (RuntimeWarning, "invalid value encountered in scalar divide"),
    )
    with warnings.catch_warnings():
        for warning_category, message_pattern in ignored:
            warnings.filterwarnings(
                "ignore",
                message=message_pattern,
                category=warning_category,
            )
        yield
# ---------------------------------------------------------------------------
# Degenerate warm-start thresholds — shared with optimization_runner.py.
# Defined here (authoritative source) so both modules stay in sync.
# ---------------------------------------------------------------------------

#: Sample fraction below which alpha_sample/D0_sample/D_offset_sample are
#: unidentifiable, causing 100% shard failure (het_bb97531f failure mode).
#: The warm-start ``f0`` is compared against this value with a strict ``<``.
CMC_F0_DEGEN_THRESHOLD: float = 0.10

#: alpha_sample value at which J_sample(t) ∝ t^α has a non-integrable
#: singularity at t→0, collapsing NUTS step-size for the sample group.
#: The warm-start ``alpha_sample`` is compared against this value with a
#: strict ``<``.
CMC_ALPHA_SINGULARITY: float = -1.5

#: Target value for ``f0`` when auto-clamping a degenerate warm-start.
#: Set to 2× the degeneracy threshold (not just above it) so NUTS leapfrog
#: steps cannot easily reflect back into the unidentifiable region.
#: The borderline value ``CMC_F0_DEGEN_THRESHOLD + 0.01 = 0.11`` used in
#: het_a10cf27e produced 47/47 shard failure even after clamping; this
#: wider margin is the het_a10cf27e fix.
CMC_F0_SAFE_ZONE: float = 0.20

#: Target value for ``alpha_sample`` when auto-clamping. Sits 0.5 above
#: ``CMC_ALPHA_SINGULARITY`` (vs the previous borderline margin of 0.1).
#: The wider margin keeps NUTS step-size adaptation stable when
#: warm-starting from the degenerate region.
CMC_ALPHA_SAFE_ZONE: float = -1.0
def _block_until_ready_pytree(tree: Any) -> Any:
    """Wait on every pytree leaf that exposes ``block_until_ready``.

    Leaves are visited in ``jax.tree_util.tree_leaves`` order. A leaf without
    a ``block_until_ready`` attribute (or whose attribute is ``None``) is
    skipped. The input object itself is returned, not a copy.
    """
    pending = (
        getattr(node, "block_until_ready", None)
        for node in jax.tree_util.tree_leaves(tree)
    )
    for wait in pending:
        if wait is None:
            continue
        wait()
    return tree
# ---------------------------------------------------------------------------
# Public: original single-run entry point (signature preserved exactly)
# ---------------------------------------------------------------------------
def _validate_config_or_raise(config: CMCConfig) -> None:
    """Raise ``ValueError`` if ``config.validate()`` reports any problem.

    ``config.validate()`` is expected to return a collection of error
    strings and not to raise itself; this helper turns a non-empty result
    into a single exception so callers cannot proceed with a bad config.

    Args:
        config: Configuration object to check.

    Raises:
        ValueError: With every reported error listed on its own line.
    """
    problems = config.validate()
    if not problems:
        return
    header = "Invalid CMCConfig:\n - "
    raise ValueError(header + "\n - ".join(problems))
# NOTE(review): stray Sphinx viewcode "[docs]" link marker removed here; it is not Python source.
def fit_cmc_jax(
    model: HeterodyneModel,
    c2_data: np.ndarray | jnp.ndarray,
    phi_angle: float = 0.0,
    config: CMCConfig | None = None,
    sigma: np.ndarray | float | None = None,
    nlsq_result: NLSQResult | None = None,
    t_override: np.ndarray | None = None,
    priors_override: dict | None = None,
    prior_width_multiplier: float = 1.0,
) -> CMCResult:
    """Run a single (unsharded) NUTS fit of the heterodyne model.

    Builds a NumPyro model (reparameterized when a successful NLSQ result is
    available and allowed, plain otherwise), samples it with NumPyro's NUTS
    kernel, and packages posterior statistics and convergence diagnostics
    into a ``CMCResult``.

    The function proceeds in four logged phases: data preparation, model
    construction, sampling, and diagnostics/result construction.

    Args:
        model: HeterodyneModel with configured parameters. ``model.t``,
            ``model.q``, ``model.dt``, ``model.scaling`` and
            ``model.param_manager`` are read.
        c2_data: Observed correlation data.
        phi_angle: Detector phi angle (degrees).
        config: CMC configuration (``CMCConfig()`` if None). It is validated
            before anything else happens.
        sigma: Measurement uncertainty. If ``None`` it is estimated with
            ``estimate_sigma(..., method="diagonal")``. Only its mean is
            forwarded to the model builders, as ``noise_scale``.
        nlsq_result: Optional NLSQ result. When present and successful it
            enables reparameterization (if ``config.use_reparam``) and chain
            initialization (if ``config.use_nlsq_warmstart``).
        t_override: Optional time array replacing ``model.t`` for model
            construction. Used by ``fit_cmc_sharded`` to pass shard time
            slices. If ``None``, falls back to ``model.t``.
        priors_override: Optional dict of pre-built NumPyro distributions
            keyed by parameter name. When provided, these distributions
            replace the default ``space.priors`` for matching parameters.
            Used by ``fit_cmc_sharded`` to inject tempered shard priors
            into the non-reparam model path. Supplying it disables
            reparameterization for this call.
        prior_width_multiplier: Scalar multiplier applied to the ``scale``
            of each reparam-path prior AFTER ``nlsq_prior_width_factor``
            scaling. Default 1.0 (no change). Used by
            ``fit_cmc_sharded`` to widen reparam priors by ``sqrt(K)``.

    Returns:
        CMCResult with posterior samples and diagnostics. If ``mcmc.run``
        raises ``RuntimeError`` or ``ValueError``, the error is logged and
        the value of ``_create_failed_result(varying_names, str(e))`` is
        returned instead.

    Raises:
        ValueError: If ``config.validate()`` reports any error.
    """
    if config is None:
        config = CMCConfig()
    _validate_config_or_raise(config)

    # Time axis handed to the model builders; shard callers override it.
    t_for_model = jnp.asarray(t_override) if t_override is not None else model.t
    logger.info(
        "[CMC] Starting analysis: chains=%d, samples=%d, warmup=%d",
        config.num_chains,
        config.num_samples,
        config.num_warmup,
    )
    # Wall-clock reference for the reported ``wall_time_seconds``.
    start_time = time.perf_counter()

    # --- Phase 1: data preparation ---
    logger.info("[CMC] Phase 1/4: data preparation")
    c2_jax = jnp.asarray(c2_data)
    _n_total = int(c2_jax.size)
    # Advisory size limit only: exceeding it logs a warning, nothing more.
    _MAX_SINGLE_SHARD = 100_000
    if _n_total > _MAX_SINGLE_SHARD:
        logger.warning(
            "[CMC] Single-shard data has %d points (> %d). "
            "NUTS is O(n) per leapfrog step — consider fit_cmc_sharded for large datasets.",
            _n_total,
            _MAX_SINGLE_SHARD,
        )
    sigma_resolved: np.ndarray | jnp.ndarray | float
    if sigma is None:
        sigma_resolved = estimate_sigma(c2_jax, method="diagonal")
        logger.info(
            "[CMC] Estimated sigma = %.4e", float(jnp.mean(jnp.asarray(sigma_resolved)))
        )
    else:
        sigma_resolved = sigma
    # NumPy arrays are converted to JAX arrays; anything else passes through.
    sigma_jax = (
        jnp.asarray(sigma_resolved)
        if isinstance(sigma_resolved, np.ndarray)
        else sigma_resolved
    )
    # Scalar prior centre for the sampled sigma site (homodyne parity).
    noise_scale = float(jnp.mean(jnp.asarray(sigma_jax)))

    # --- Phase 2: model construction ---
    logger.info("[CMC] Phase 2/4: model construction")
    space = model.param_manager.space
    varying_names = model.param_manager.varying_names
    # Read fitted contrast/offset from model scaling (angle_idx=0 for per-angle CMC).
    # After NLSQ, model.scaling holds the fitted values; passing 1.0 defaults would
    # silently use the wrong scaling and bias the entire posterior.
    contrast, offset = model.scaling.get_for_angle(0)
    logger.info(
        "[CMC] Using contrast=%.4f, offset=%.4f from model scaling",
        contrast,
        offset,
    )
    logger.info("[CMC] Sampling %d parameters: %s", len(varying_names), varying_names)

    # Validate NLSQ warm-start
    # Reparameterization needs a successful NLSQ result to centre/scale on.
    use_reparam = config.use_reparam and nlsq_result is not None and nlsq_result.success
    # priors_override is a hard contract: caller wants these exact distributions.
    # The reparam path samples in z-space and ignores override → silent data loss.
    # Force the non-reparam path so the override is actually used.
    if priors_override is not None and use_reparam:
        logger.info(
            "[CMC] priors_override provided; disabling reparameterization so the "
            "caller-supplied distributions are sampled directly."
        )
        use_reparam = False
    # Warm-start requested but unusable: warn only, sampling still proceeds.
    if config.use_nlsq_warmstart and nlsq_result is None:
        logger.warning(
            "[CMC] NLSQ warm-start is enabled (use_nlsq_warmstart=True) but no "
            "nlsq_result was provided. Chains will initialize at the prior, which "
            "is typically 5-10σ from the true posterior for the 14-parameter "
            "heterodyne model. This produces R-hat >> 1 and ESS ≈ n_chains — "
            "effectively a failed run after hours of sampling. "
            "Pass nlsq_result= or use optimizer: both in the CLI config."
        )
    elif (
        config.use_nlsq_warmstart
        and nlsq_result is not None
        and not nlsq_result.success
    ):
        logger.warning(
            "[CMC] NLSQ warm-start requested but result is not converged "
            "(success=False); falling back to default initialization"
        )

    # Function-scope import (kept local in the original source).
    from heterodyne.optimization.cmc.scaling import ParameterScaling

    # Reparam bookkeeping; stays at these defaults on the non-reparam path.
    reparam_config: ReparamConfig | None = None
    scalings: dict[str, ParameterScaling] | None = None
    reparam_values: dict[str, float] = {}
    prior_std_dict: dict[str, float] = {}
    if use_reparam:
        # ``use_reparam`` is only True when nlsq_result is not None and
        # success=True (see set-up above) — assert narrows for pyright.
        assert nlsq_result is not None
        t_array = (
            np.asarray(t_override) if t_override is not None else np.asarray(model.t)
        )
        # Step taken from the first two samples; a single-point axis falls
        # back to that point's value. An empty axis is not handled here.
        dt_val = (
            float(t_array[1] - t_array[0]) if len(t_array) > 1 else float(t_array[0])
        )
        t_max_val = float(t_array[-1])
        t_ref = compute_t_ref(dt_val, t_max_val, fallback_value=1.0)
        # reparameterization_d_total controls both D0_ref/alpha_ref and
        # D0_sample/alpha_sample pairs; reparameterization_log_gamma controls v0/beta.
        reparam_config = ReparamConfig(
            t_ref=t_ref,
            enable_d_ref=config.reparameterization_d_total,
            enable_d_sample=config.reparameterization_d_total,
            enable_v_ref=config.reparameterization_log_gamma,
        )
        logger.info(
            "[CMC] Reference-time reparameterization: t_ref=%.4e "
            "(d_ref=%s, d_sample=%s, v_ref=%s)",
            t_ref,
            config.reparameterization_d_total,
            config.reparameterization_d_total,
            config.reparameterization_log_gamma,
        )
        # NLSQ point estimates for the varying parameters the result knows.
        nlsq_values = {
            name: float(nlsq_result.get_param(name))
            for name in varying_names
            if name in nlsq_result.parameter_names
        }
        # NLSQ uncertainties, skipping parameters that report ``None``.
        nlsq_uncertainties = {}
        for _unc_name in varying_names:
            if _unc_name in nlsq_result.parameter_names:
                _unc_val = nlsq_result.get_uncertainty(_unc_name)
                if _unc_val is not None:
                    nlsq_uncertainties[_unc_name] = float(_unc_val)
        reparam_values, reparam_uncertainties = transform_nlsq_to_reparam_space(
            nlsq_values,
            nlsq_uncertainties,
            t_ref,
            reparam_config,
        )
        scalings = {}
        # Map each prefactor to its reparameterized name, but only when both
        # members of the (prefactor, exponent) pair are being sampled.
        prefactor_to_log: dict[str, str] = {}
        for prefactor, exponent in reparam_config.enabled_pairs:
            if prefactor in varying_names and exponent in varying_names:
                prefactor_to_log[prefactor] = reparam_config.get_reparam_name(prefactor)
        # One ParameterScaling per sampled site, keyed by its sampling name.
        for name in varying_names:
            if name in prefactor_to_log:
                sname = prefactor_to_log[name]
            else:
                sname = name
            # Centre: reparam value, else the configured value, else 0.0.
            center = reparam_values.get(sname, space.values.get(name, 0.0))
            unc = reparam_uncertainties.get(sname, 0.0)
            # Missing or non-positive uncertainty falls back to unit scale.
            scale = unc * config.nlsq_prior_width_factor if unc > 0 else 1.0
            scale = max(scale, 1e-10)
            scale = (
                scale * prior_width_multiplier
            )  # temper reparam prior width for CMC shards
            if sname.startswith("log_"):
                # "log_"-prefixed sites get a symmetric ±10·scale window.
                low = center - 10.0 * scale
                high = center + 10.0 * scale
            else:
                low, high = space.bounds[name]
            scalings[sname] = ParameterScaling(
                name=sname,
                center=center,
                scale=scale,
                low=low,
                high=high,
            )
        # Prior widths keyed by physics-space name, stored in result metadata.
        prior_std_dict = {}
        for name in varying_names:
            if name in prefactor_to_log:
                sname = prefactor_to_log[name]
                sc = scalings[sname]
                # Remapped sites: exp(center) * scale.
                center_physics = float(np.exp(sc.center))
                prior_std_dict[name] = center_physics * sc.scale
            else:
                sname = name
                if sname in scalings:
                    prior_std_dict[name] = scalings[sname].scale
        # priors_override is now handled by forcing use_reparam=False above,
        # so this branch only runs when no override is present.
        numpyro_model = get_heterodyne_model_reparam(
            t=t_for_model,
            q=model.q,
            dt=model.dt,
            phi_angle=phi_angle,
            c2_data=c2_jax,
            noise_scale=noise_scale,
            space=space,
            reparam_config=reparam_config,
            scalings=scalings,
            contrast=contrast,
            offset=offset,
        )
    else:
        numpyro_model = get_heterodyne_model(
            t=t_for_model,
            q=model.q,
            dt=model.dt,
            phi_angle=phi_angle,
            c2_data=c2_jax,
            noise_scale=noise_scale,
            space=space,
            contrast=contrast,
            offset=offset,
            priors_override=priors_override,
        )

    # --- Phase 3: sampling ---
    logger.info("[CMC] Phase 3/4: NUTS sampling")
    from numpyro.infer import initialization as numpyro_init

    _init_strategy_map = {
        "init_to_median": numpyro_init.init_to_median,
        "init_to_sample": numpyro_init.init_to_sample,
        "init_to_value": numpyro_init.init_to_value,
    }
    # Unrecognised strategy names fall back to init_to_median without warning.
    init_fn = _init_strategy_map.get(config.init_strategy, numpyro_init.init_to_median)
    # Elevate target acceptance for high-correlation regimes: when Z-space
    # reparameterization is active the power-law pair geometry is decorrelated
    # within each pair, but cross-pair correlations remain. A floor of 0.9
    # (matching homodyne's laminar-flow policy) improves leapfrog step quality.
    _MIN_TARGET_ACCEPT_REPARAM = 0.9
    effective_target_accept = (
        max(config.target_accept_prob, _MIN_TARGET_ACCEPT_REPARAM)
        if use_reparam
        else config.target_accept_prob
    )
    if use_reparam and effective_target_accept > config.target_accept_prob:
        logger.info(
            "[CMC] Elevating target_accept_prob %.2f → %.2f (reparam active)",
            config.target_accept_prob,
            effective_target_accept,
        )
    kernel = NUTS(
        numpyro_model,
        target_accept_prob=effective_target_accept,
        max_tree_depth=config.max_tree_depth,
        dense_mass=config.dense_mass,
        # The strategy factory is invoked with no arguments.
        init_strategy=init_fn(),
    )
    # No configured seed: draw one so the run still has a concrete seed.
    rng_seed = config.seed if config.seed is not None else secrets.randbelow(2**31)
    init_params = None
    if config.use_nlsq_warmstart and nlsq_result is not None and nlsq_result.success:
        logger.info("[CMC] Using NLSQ result for chain initialization")
        # NOTE(review): NumPyro documents MCMC.run(init_params=...) as values in
        # unconstrained space for model-based kernels — confirm the values built
        # below (z-space sites, physics-space sites, sigma) are valid as given.
        if use_reparam and scalings is not None:
            init_params = {}
            # Separate key stream (seed + 1) for the per-chain jitter.
            perturb_key = jax.random.PRNGKey(rng_seed + 1)
            for sname, sc in scalings.items():
                perturb_key, subkey = jax.random.split(perturb_key)
                reparam_val = reparam_values.get(sname, sc.center)
                z_init = sc.to_normalized(reparam_val)
                # Same start for every chain plus small Gaussian jitter (×0.01).
                base = jnp.full((config.num_chains,), jnp.float64(z_init))
                perturbation = 0.01 * jax.random.normal(
                    subkey, shape=(config.num_chains,)
                )
                init_params[f"{sname}_z"] = base + perturbation
            # sigma is sampled as a posterior site; initialise at prior centre so
            # NUTS init_strategy doesn't call the model without a seed handler.
            init_params["sigma"] = jnp.full(
                (config.num_chains,), jnp.float64(noise_scale * 1.5)
            )
        else:
            init_params = {}
            perturb_key = jax.random.PRNGKey(rng_seed + 1)
            # Only parameters present in the NLSQ result get an initial value.
            for name in varying_names:
                if name in nlsq_result.parameter_names:
                    perturb_key, subkey = jax.random.split(perturb_key)
                    base = jnp.full(
                        (config.num_chains,),
                        jnp.float64(nlsq_result.get_param(name)),
                    )
                    perturbation = 0.01 * jax.random.normal(
                        subkey, shape=(config.num_chains,)
                    )
                    init_params[name] = base + perturbation
            init_params["sigma"] = jnp.full(
                (config.num_chains,), jnp.float64(noise_scale * 1.5)
            )
    mcmc = MCMC(
        kernel,
        num_warmup=config.num_warmup,
        num_samples=config.num_samples,
        num_chains=config.num_chains,
        progress_bar=True,
    )
    rng_key = jax.random.PRNGKey(rng_seed)
    try:
        # "energy" and "diverging" are read by the diagnostics further down.
        mcmc.run(rng_key, init_params=init_params, extra_fields=("energy", "diverging"))
        samples = _block_until_ready_pytree(mcmc.get_samples())
    except (RuntimeError, ValueError) as e:
        # Sampling failures are reported through the result, not re-raised.
        logger.error("[CMC] MCMC sampling failed: %s", e)
        return _create_failed_result(varying_names, str(e))

    # --- Phase 4: diagnostics and output ---
    # Largest leading dimension across sample arrays (0 if there are none).
    sample_count = max(
        (np.asarray(values).shape[0] for values in samples.values()), default=0
    )
    logger.info(
        "[CMC] NUTS sampling complete: collected %d posterior draws", sample_count
    )
    logger.info("[CMC] Phase 4/4: diagnostics and result construction")
    # arviz_base.io_numpyro accesses numpyro.infer.initialization as an attribute.
    # Heterodyne's import chain loads the submodule but doesn't always register it
    # as a package attribute (Python import-order quirk). Set it explicitly.
    import sys as _sys

    import numpyro.infer as _numpyro_infer

    if not hasattr(_numpyro_infer, "initialization"):
        _init_mod = _sys.modules.get("numpyro.infer.initialization")
        if _init_mod is not None:
            setattr(_numpyro_infer, "initialization", _init_mod)  # noqa: B010
    idata = az.from_numpyro(mcmc)
    if use_reparam and reparam_config is not None:
        raw_samples = {k: np.asarray(v) for k, v in samples.items()}
        physics_samples = transform_to_physics_space(raw_samples, reparam_config)
        output_names = varying_names
        # Summarise only the names present in ``idata.posterior``; with none
        # available the summary is ``None`` and the extractors receive that.
        available_names = [n for n in output_names if n in idata.posterior]
        with _silence_arviz_diagnostic_warnings():
            summary = (
                az.summary(idata, var_names=available_names, ci_prob=0.95)
                if available_names
                else None
            )
    else:
        physics_samples = {k: np.asarray(v) for k, v in samples.items()}
        output_names = varying_names
        with _silence_arviz_diagnostic_warnings():
            summary = az.summary(idata, var_names=output_names, ci_prob=0.95)
    posterior_mean, posterior_std, r_hat, ess_bulk, ess_tail = _extract_posterior_stats(
        output_names,
        physics_samples,
        summary,
    )
    credible_intervals = _extract_credible_intervals(
        output_names, physics_samples, summary
    )
    bfmi, bfmi_compute_failed = _compute_bfmi(idata)
    samples_dict = {
        name: physics_samples[name] for name in output_names if name in physics_samples
    }
    # The "MAP" estimate is a copy of the posterior mean; no optimisation runs here.
    map_estimate = posterior_mean.copy()
    wall_time = time.perf_counter() - start_time
    # Convergence gate: at least one finite R-hat and ESS value must exist,
    # and every finite value must satisfy its threshold. NaNs are ignored.
    r_hat_finite = r_hat[~np.isnan(r_hat)]
    ess_finite = ess_bulk[~np.isnan(ess_bulk)]
    convergence_passed = bool(
        len(r_hat_finite) > 0
        and np.all(r_hat_finite < config.max_r_hat)
        and len(ess_finite) > 0
        and np.all(ess_finite > config.min_ess)
    )
    # BFMI is advisory only (homodyne parity). Homodyne check_convergence uses
    # R-hat + ESS as the sole hard gates. Applying BFMI as a hard gate here
    # silently kills all shards when chains start near a parameter boundary
    # (boundary reflection → low BFMI is expected and normal).
    if bfmi is not None and not bfmi_compute_failed:
        _min_bfmi = float(np.nanmin(np.asarray(bfmi, dtype=float)))
        if _min_bfmi < config.min_bfmi:
            logger.warning(
                "[CMC] Low BFMI=%.3f < %.2f — poor HMC energy exploration "
                "(advisory only; does not affect convergence gate)",
                _min_bfmi,
                config.min_bfmi,
            )
    if bfmi_compute_failed:
        logger.debug(
            "[CMC] BFMI unavailable (az.bfmi computation failed); "
            "convergence determined by R-hat and ESS only"
        )
    metadata: dict[str, Any] = {}
    # Store divergence_rate so fit_cmc_sharded can filter high-divergence shards
    # before consensus combination (CM-02 fix).
    _extra = mcmc.get_extra_fields()
    _div = _extra.get("diverging", None)
    if _div is not None:
        _div_arr = np.asarray(_div, dtype=bool)
        # Fraction of divergent transitions; an empty array counts as 0.0.
        metadata["divergence_rate"] = (
            float(np.mean(_div_arr)) if _div_arr.size > 0 else 0.0
        )
    if use_reparam and reparam_config is not None:
        metadata["t_ref"] = reparam_config.t_ref
        metadata["prior_std"] = prior_std_dict
    result = CMCResult(
        parameter_names=output_names,
        posterior_mean=posterior_mean,
        posterior_std=posterior_std,
        credible_intervals=credible_intervals,
        convergence_passed=convergence_passed,
        r_hat=r_hat,
        ess_bulk=ess_bulk,
        ess_tail=ess_tail,
        bfmi=bfmi,
        samples=samples_dict,
        map_estimate=map_estimate,
        num_warmup=config.num_warmup,
        num_samples=config.num_samples,
        num_chains=config.num_chains,
        wall_time_seconds=wall_time,
        metadata=metadata,
    )
    # The report is logged only; ``convergence_passed`` above is what the
    # result carries.
    conv_report = validate_convergence(
        result, config.max_r_hat, config.min_ess, config.min_bfmi
    )
    for msg in conv_report.messages:
        logger.info(msg)
    logger.info(
        "[CMC] Complete in %.1fs, convergence: %s",
        wall_time,
        "PASSED" if convergence_passed else "FAILED",
    )
    # Divergence analysis (parity with homodyne)
    div_report = analyze_divergences(result)
    for msg in div_report.messages:
        logger.warning(msg)
    # Structured analysis summary (parity with homodyne)
    _r_hat_dict = (
        {n: float(result.r_hat[i]) for i, n in enumerate(result.parameter_names)}
        if result.r_hat is not None
        else {}
    )
    _ess_dict = (
        {n: float(result.ess_bulk[i]) for i, n in enumerate(result.parameter_names)}
        if result.ess_bulk is not None
        else {}
    )
    log_analysis_summary(
        # Use the result's own status string if set, else derive one.
        convergence_status=result.convergence_status
        or ("converged" if result.convergence_passed else "not_converged"),
        r_hat=_r_hat_dict,
        ess_bulk=_ess_dict,
        divergences=result.divergences or 0,
        n_samples=config.num_samples,
        n_chains=config.num_chains,
        # A single run is reported as one shard.
        n_shards=1,
        shards_succeeded=1 if result.convergence_passed else 0,
        execution_time=wall_time,
    )
    return result
# ---------------------------------------------------------------------------
# Public: sharded CMC entry point
# ---------------------------------------------------------------------------
# NOTE(review): stray Sphinx viewcode "[docs]" link marker removed here; it is not Python source.
def fit_cmc_sharded(
model: HeterodyneModel,
c2_data: np.ndarray | jnp.ndarray,
phi_angle: float = 0.0,
config: CMCConfig | None = None,
sigma: np.ndarray | float | None = None,
nlsq_result: NLSQResult | None = None,
num_shards: int = 4,
sharding_strategy: str = "random",
shard_seed: int | None = None,
) -> CMCResult:
"""Fit heterodyne model using sharded Consensus Monte Carlo.
Splits the observed c2 matrix into ``num_shards`` independent data
subsets, runs NUTS on each shard sub-posterior (sequentially), then
combines the shard posteriors via inverse-variance weighted consensus.
Prior tempering is applied automatically: each shard's prior distribution
is widened by ``sqrt(num_shards)`` (i.e., ``prior^(1/K)``) while sigma
is passed unscaled. This is the correct Consensus Monte Carlo approach
(Scott et al., 2016).
Args:
model: HeterodyneModel with configured parameters.
c2_data: Observed two-time correlation matrix (N x N).
phi_angle: Detector phi angle (degrees).
config: CMC configuration (defaults to CMCConfig()).
sigma: Measurement uncertainty (estimated if None).
nlsq_result: Optional NLSQ result for warm-starting each shard.
num_shards: Number of data shards (K). Must be >= 2.
sharding_strategy: One of ``"random"`` (default) or
``"contiguous"``. Random sharding breaks temporal
autocorrelation between shards. Contiguous sharding uses
diagonal time-blocks, which preserves the two-time structure
within each shard.
shard_seed: Integer seed for deterministic shard assignment.
If ``None``, a random seed is drawn from the OS.
Returns:
CMCResult with combined posterior and per-shard diagnostics stored
in ``result.metadata["shard_diagnostics"]``.
Raises:
ValueError: If inputs fail validation or ``num_shards < 2``.
"""
if config is None:
config = CMCConfig()
_validate_config_or_raise(config)
if num_shards < 2:
raise ValueError(f"num_shards must be >= 2 for sharded CMC, got {num_shards}")
# Validate inputs before touching JAX
_validate_cmc_inputs(c2_data, sigma, model.param_manager.space)
logger.info(
"[CMC-sharded] Starting: %d shards, strategy=%s, chains=%d, samples=%d",
num_shards,
sharding_strategy,
config.num_chains,
config.num_samples,
)
start_time = time.perf_counter()
# --- Phase 1: data preparation ---
logger.info("[CMC-sharded] Phase 1/5: data preparation")
c2_np = np.asarray(c2_data, dtype=np.float64)
if sigma is None:
sigma_jax_full = estimate_sigma(jnp.asarray(c2_np), method="diagonal")
sigma_np: np.ndarray | float = np.asarray(sigma_jax_full)
logger.info(
"[CMC-sharded] Estimated sigma = %.4e",
float(np.mean(np.asarray(sigma_np))),
)
else:
sigma_np = np.asarray(sigma) if not isinstance(sigma, float) else sigma
# --- Phase 2: shard creation ---
logger.info("[CMC-sharded] Phase 2/5: creating %d shards", num_shards)
effective_seed = shard_seed if shard_seed is not None else secrets.randbelow(2**31)
shards = _create_shards(
c2_np, sigma_np, num_shards, sharding_strategy, effective_seed
)
logger.info(
"[CMC-sharded] Shards created: sizes=%s",
[len(s["indices"]) for s in shards],
)
# --- Build tempered priors for CMC shards ---
# Correct CMC tempering: widen prior by sqrt(K) per shard (prior^(1/K)).
# Workers rebuild priors locally; _base_priors here is computed so that
# build_nlsq_informed_priors / build_default_priors / temper_priors are
# imported at module level so tests can patch them on this module.
# Actual tempering flows via prior_width_multiplier passed to run_shards().
_space = model.param_manager.space
_scaling_active = [
n for n in _space.varying_names if n not in _space.varying_physics_names
]
if _scaling_active:
logger.debug(
"[CMC-sharded] ParameterSpace has scaling params active (%s); "
"workers will use varying_physics_names (physics-only) for NUTS.",
_scaling_active,
)
if nlsq_result is not None and nlsq_result.success:
base_priors = build_nlsq_informed_priors(
nlsq_result, _space, width_factor=config.nlsq_prior_width_factor
)
else:
base_priors = build_default_priors(
_space,
use_log_space_priors=getattr(config, "use_log_space_priors", True),
)
_shard_priors = temper_priors(base_priors, num_shards)
prior_width_mult = math.sqrt(num_shards)
# --- Phase 3: per-shard sampling (parallel) ---
base_seed = config.seed if config.seed is not None else secrets.randbelow(2**31)
# Derive time axis from C2 matrix shape, not model.t. NLSQ trim calls
# sync_time_axis(np.arange(1000)) which shrinks model.t to 1000 elements,
# but CMC receives the full (1001×1001) C2 — shard t1_idx/t2_idx reach
# index 1000, causing OOB if t_np is taken from the trimmed model.t.
t_np = np.arange(c2_np.shape[0], dtype=np.float64)
contrast, offset = model.scaling.get_for_angle(0)
q_val = float(model.q)
dt_val = float(model.dt)
# Build reparameterization config for shard workers (parity with fit_cmc_jax).
# Workers use reparam_config_dict to sample D0/alpha in log-space, which
# greatly reduces the D0–alpha correlation and improves NUTS acceptance rate.
# Requires an NLSQ warm-start to compute t_ref; falls back to None (raw space).
_use_reparam = (
config.use_reparam and nlsq_result is not None and nlsq_result.success
)
_reparam_config: ReparamConfig | None = None
if _use_reparam:
_t_max_val = float(t_np[-1]) if len(t_np) > 0 else 1.0
_t_ref = compute_t_ref(dt_val, _t_max_val, fallback_value=1.0)
_reparam_config = ReparamConfig(
t_ref=_t_ref,
enable_d_ref=config.reparameterization_d_total,
enable_d_sample=config.reparameterization_d_total,
enable_v_ref=config.reparameterization_log_gamma,
)
logger.info(
"[CMC-sharded] Reparameterization enabled: t_ref=%.4e "
"(d_ref=%s, d_sample=%s, v_ref=%s)",
_t_ref,
config.reparameterization_d_total,
config.reparameterization_d_total,
config.reparameterization_log_gamma,
)
_reparam_config_dict: dict[str, Any] | None = (
{
"enable_d_ref": _reparam_config.enable_d_ref,
"enable_d_sample": _reparam_config.enable_d_sample,
"enable_v_ref": _reparam_config.enable_v_ref,
"t_ref": _reparam_config.t_ref,
}
if _reparam_config is not None
else None
)
# Translate _create_shards output format to the dict format run_shards() expects.
# Two wire formats are supported depending on the sharding strategy:
#
# Element-wise format (random strategy, t1_idx/t2_idx present):
# "t1"/"t2" — per-pair time values; "time_grid" — full axis for ShardGrid
# Worker uses compute_c2_elementwise → 1-D output matching flat c2_data.
#
# Meshgrid format (contiguous strategy, t_indices present):
# "t" — 1-D time axis for the shard block
# Worker uses compute_c2_heterodyne → 2-D output matching square c2_data.
parallel_shards: list[dict[str, Any]] = []
for shard in shards:
_sigma_arr = shard["sigma_shard"]
_sigma_wire = (
np.asarray(_sigma_arr) if not isinstance(_sigma_arr, float) else _sigma_arr
)
_noise_scale = float(
np.mean(np.asarray(_sigma_arr))
if not isinstance(_sigma_arr, float)
else _sigma_arr
)
_base: dict[str, Any] = {
"c2_data": np.asarray(shard["c2_shard"]),
"sigma": _sigma_wire,
"noise_scale": _noise_scale,
"q": q_val,
"dt": dt_val,
"phi_angle": phi_angle,
"contrast": float(contrast),
"offset": float(offset),
"n_phi": 1,
"reparam_config_dict": _reparam_config_dict,
}
if "t1_idx" in shard:
# Element-wise (random): pass paired time values + full axis
_base["t1"] = t_np[shard["t1_idx"]]
_base["t2"] = t_np[shard["t2_idx"]]
_base["time_grid"] = t_np
else:
# Meshgrid (contiguous): pass 1-D sub-axis
_base["t"] = t_np[shard["t_indices"]]
parallel_shards.append(_base)
# NLSQ warm-start values passed to workers for chain initialisation.
initial_values: dict[str, Any] | None = None
nlsq_uncertainties_dict: dict[str, float] | None = None
if nlsq_result is not None and nlsq_result.success:
initial_values = {
name: float(nlsq_result.get_param(name))
for name in nlsq_result.parameter_names
}
# Also pass uncertainties so each worker can build NLSQ-informed
# (TruncatedNormal centered on NLSQ value) priors locally and apply
# CMC tempering on TOP of them. Without this, workers fall back to
# registry defaults and the NLSQ posterior contraction is lost.
unc_dict: dict[str, float] = {}
for name in nlsq_result.parameter_names:
unc = nlsq_result.get_uncertainty(name)
if unc is not None and float(unc) > 0:
unc_dict[name] = float(unc)
nlsq_uncertainties_dict = unc_dict if unc_dict else None
else:
# Fall back to model-configured initial values (e.g. NLSQ results
# pre-populated in the config's `parameters:` section). Without a full
# NLSQResult, NLSQ-informed priors and uncertainties are unavailable,
# so priors remain registry-based — warmup may need more steps.
_varying = set(model.varying_names)
_model_vals = model.get_params_dict()
if _model_vals:
initial_values = {
name: float(v) for name, v in _model_vals.items() if name in _varying
}
# Clamp model-default values away from hard bounds (same protection
# that _clamp_warmstart_to_interior applies to NLSQ results in the CLI
# path). Without clamping, a config like ``parameters: {alpha_ref: -5.0}``
# places the start exactly on the boundary wall, causing NUTS leapfrog
# reflections that degrade BFMI for the whole run.
if initial_values:
_FALLBACK_MARGIN = 5e-2
_iv_clamped: dict[str, float] = {}
for _fname, _fval in initial_values.items():
if _fname in DEFAULT_REGISTRY:
_finfo = DEFAULT_REGISTRY[_fname]
_fspan = _finfo.max_bound - _finfo.min_bound
_flo = _finfo.min_bound + _FALLBACK_MARGIN * _fspan
_fhi = _finfo.max_bound - _FALLBACK_MARGIN * _fspan
_iv_clamped[_fname] = (
float(np.clip(_fval, _flo, _fhi))
if _flo < _fhi
else float(_fval)
)
else:
_iv_clamped[_fname] = float(_fval)
initial_values = _iv_clamped
# Pre-dispatch D_total sign guard (het_c7fb5859 prevention).
# Reparameterisation samples D_total = D0 + D_offset per transport group.
# If D_total ≤ 0 at the warm-start point the prior has log_prob = −∞ →
# every NUTS leapfrog proposal is rejected → BFMI=0.000 and R-hat=NaN
# across all shards. Clamp D_offset so D_total = 1% of D0 (tiny but
# positive) before workers are dispatched.
if _use_reparam and config.reparameterization_d_total and initial_values:
_iv_mutable: dict[str, Any] | None = None # lazy copy-on-write
for _grp in ("ref", "sample"):
_d0 = initial_values.get(f"D0_{_grp}")
_doff = initial_values.get(f"D_offset_{_grp}")
if _d0 is not None and _doff is not None:
_d_total = float(_d0) + float(_doff)
if _d_total <= 0.0:
_new_doff = -0.99 * float(_d0) # D_total = 0.01×D0 > 0
logger.warning(
"[CMC-sharded] D_total_%s = D0_%s + D_offset_%s = %.3g ≤ 0 — "
"reparameterised prior requires D_total > 0 "
"(log-prior = −∞ at warm-start → BFMI=0 on all shards). "
"Clamping D_offset_%s %.3g → %.3g. "
"Root cause: stale NLSQ warm-start with degenerate parameter "
"combination; re-run NLSQ with current config bounds.",
_grp,
_grp,
_grp,
_d_total,
_grp,
float(_doff),
_new_doff,
)
if _iv_mutable is None:
_iv_mutable = dict(initial_values)
_iv_mutable[f"D_offset_{_grp}"] = _new_doff
if _iv_mutable is not None:
initial_values = _iv_mutable
# Degenerate warm-start detector (het_bb97531f failure mode).
# Two compounding conditions cause 100% shard bad_convergence and BFMI=0.000:
# (a) f0 ≈ 0 → sample fraction near-zero → alpha_sample, D0_sample,
# D_offset_sample are unidentifiable from data. Per-shard posteriors
# are dominated by the tempered prior and chains must thermalize over
# the full prior range during warmup with effectively zero likelihood
# gradient for the sample-transport parameter group.
# (b) alpha_sample < -1.5 → J_sample(t) ∝ t^α has a non-integrable
# singularity at t→0 (∫₀^τ t^α dt diverges for α ≤ -1). Even if f0
# is moderate, the NUTS gradient for alpha_sample collapses to zero at
# short lags (exp(-q²·half_tr_sample) → 0) while being huge at the
# first time-grid point → step-size adaptation breaks down.
# Together these guarantee that all 47 shards fail convergence, wasting
# hours of compute. Warn before dispatch so the user can act (e.g. freeze
# degenerate parameters or increase num_warmup) without waiting 7+ hours.
if initial_values:
_f0_iv = initial_values.get("f0")
_alpha_s_iv = initial_values.get("alpha_sample")
_f0_degen = _f0_iv is not None and float(_f0_iv) < CMC_F0_DEGEN_THRESHOLD
_alpha_sing = (
_alpha_s_iv is not None and float(_alpha_s_iv) < CMC_ALPHA_SINGULARITY
)
if _f0_degen or _alpha_sing:
_parts: list[str] = []
if _f0_degen:
_parts.append(
f"f0={float(_f0_iv):.4f} < {CMC_F0_DEGEN_THRESHOLD} — " # type: ignore[arg-type]
"sample fraction near-zero → alpha_sample / D0_sample / "
"D_offset_sample are unidentifiable; posterior ≈ tempered prior"
)
if _alpha_sing:
_parts.append(
f"alpha_sample={float(_alpha_s_iv):.3f} < {CMC_ALPHA_SINGULARITY} — " # type: ignore[arg-type]
"J_sample ∝ t^α has non-integrable singularity at short lags; "
"NUTS step-size collapses for the sample-transport group"
)
logger.warning(
"[CMC-sharded] Degenerate warm-start detected (het_bb97531f failure "
"mode) — convergence is unlikely without intervention:\n %s\n"
"Recommended fixes:\n"
" 1. Freeze the unidentifiable parameters in your YAML config:\n"
" optimization:\n"
" cmc:\n"
" fixed_params:\n"
" alpha_sample: %.3f\n"
" D0_sample: %.3g\n"
" 2. Increase num_warmup to ≥2000 (default 1500 may be insufficient\n"
" when warm-start is far from per-shard posterior mode).\n"
" 3. If f0 < 0.05, consider disabling the sample component\n"
" entirely (fix f0=0) and running a reference-only model.",
";\n ".join(_parts),
float(_alpha_s_iv) if _alpha_s_iv is not None else float("nan"),
float(initial_values.get("D0_sample", float("nan"))),
)
# Soft abort: auto-clamp the offending warm-start values into the
# safe zone instead of raising — the previous hard ``raise
# RuntimeError`` killed the entire run on a single threshold trip
# with no way to recover. The clamp pushes alpha_sample just above
# the t^α singularity and f0 just above the identifiability floor;
# posteriors may still be wide but the run completes and the
# downstream ``min_success_rate`` gate can decide whether the
# output is usable. ``allow_degenerate_warmstart=True`` keeps the
# raw NLSQ values (no clamp), for callers who want to observe the
# full degenerate behaviour (e.g. test harnesses, audits).
if not config.allow_degenerate_warmstart:
_iv_degen_clamped: dict[str, Any] = dict(initial_values)
if _alpha_sing:
_iv_degen_clamped["alpha_sample"] = CMC_ALPHA_SAFE_ZONE
if _f0_degen:
_iv_degen_clamped["f0"] = CMC_F0_SAFE_ZONE
logger.warning(
"[CMC-sharded] Auto-clamping degenerate warm-start to safe "
"zone (alpha_sample=%.3f, f0=%.4f) — set "
"allow_degenerate_warmstart: true to keep raw NLSQ values.",
float(_iv_degen_clamped.get("alpha_sample", float("nan"))),
float(_iv_degen_clamped.get("f0", float("nan"))),
)
initial_values = _iv_degen_clamped
# Log rough runtime estimate before blocking and warn if it exceeds timeout.
avg_pts = sum(int(np.asarray(s["c2_data"]).size) for s in parallel_shards) // max(
num_shards, 1
)
_n_workers = _estimate_n_workers()
_estimated_total = _log_runtime_estimate(
logger,
n_shards=num_shards,
n_chains=config.num_chains,
n_warmup=config.num_warmup,
n_samples=config.num_samples,
avg_points_per_shard=avg_pts,
n_workers=_n_workers,
)
_batches = (num_shards + _n_workers - 1) // _n_workers
_estimated_per_shard = _estimated_total / max(_batches, 1)
if _estimated_per_shard > config.per_shard_timeout:
logger.warning(
"[CMC-sharded] Estimated per-shard time (%.0fs = %.1fh) exceeds "
"per_shard_timeout=%ds. Shards will likely timeout. "
"avg_points_per_shard=%d exceeds the ~100K NUTS limit. "
"Use num_shards='auto' or reduce max_points_per_shard.",
_estimated_per_shard,
_estimated_per_shard / 3600,
config.per_shard_timeout,
avg_pts,
)
# Early abort guard: without any warm-start, NUTS starts from the default
# prior (identity mass matrix) and must discover the posterior geometry
# from scratch during warmup. For the 14-parameter heterodyne model with
# shards >10K points, warmup alone exceeds 7200s because NUTS saturates
# max_tree_depth (1024 leapfrog steps) on every iteration.
# This is the het_c7548ee8 failure mode: all 47 shards timeout with 0
# posterior samples collected.
_NO_NLSQ_SHARD_LIMIT = 10_000
if nlsq_result is None and avg_pts > _NO_NLSQ_SHARD_LIMIT:
if initial_values:
# Model-configured initial values (e.g. NLSQ results pre-populated
# in config's `parameters:` section) provide a good starting point.
# NLSQ-informed priors and uncertainties are unavailable, so the
# mass matrix adapts from identity during warmup — expect more warmup
# steps than with a full NLSQResult, but the run will not timeout.
logger.warning(
"[CMC-sharded] No NLSQ result object provided; using model-configured "
"initial values as fallback warm-start (%d varying params). "
"Priors are registry-based (not NLSQ-informed) — warmup may be slower.",
len(initial_values),
)
else:
# Hard abort: 3 separate runs (het_c7548ee8, het_e34fa942, het_dd0f825b)
# prove that CMC without ANY warm-start on >10K-point shards ALWAYS
# timeouts after 8+ hours with 0 posterior samples. NUTS must discover
# a 14-parameter posterior geometry from scratch (identity mass matrix).
# Warmup alone saturates max_tree_depth=10 (1024 leapfrog steps/step)
# on every iteration, far exceeding per_shard_timeout=7200s.
# Abort immediately to prevent silent 8-hour waste.
raise RuntimeError(
f"[CMC-sharded] Aborting: no NLSQ warm-start provided and "
f"avg_points_per_shard={avg_pts} > {_NO_NLSQ_SHARD_LIMIT}. "
f"Without a warm-start, all {num_shards} shards will timeout "
f"({config.per_shard_timeout}s) with 0 posterior samples collected. "
"Fix: run NLSQ first (optimizer: nlsq) then re-run CMC, or use "
"optimizer: both to run NLSQ→CMC in one pass. "
"To override (e.g. for small pilot runs), reduce max_points_per_shard "
f"below {_NO_NLSQ_SHARD_LIMIT} in the CMC config."
)
# Codex W2: honour config.backend_name instead of hardcoding MP.
# Only the sharded execution backends expose ``run_shards()`` — pjit and
# the single-process CPU backend operate on a single chain at a time and
# are unsuitable for the K-shard CMC pipeline. We surface a clear error
# rather than silently rerouting to MP, so the user sees that their
# explicit choice was unsupported in this path.
_backend_name = (getattr(config, "backend_name", "auto") or "auto").lower()
if _backend_name == "jax":
_backend_name = "multiprocessing" # legacy alias
if _backend_name in ("multiprocessing", "auto", "slurm"):
from heterodyne.optimization.cmc.backends.multiprocessing import (
MultiprocessingBackend,
)
_backend: Any = MultiprocessingBackend()
if _backend_name == "slurm":
logger.warning(
"[CMC-sharded] backend_name='slurm' has no native sharded "
"backend; falling back to MultiprocessingBackend "
"(run from inside the SLURM allocation)"
)
elif _backend_name == "pbs":
from heterodyne.optimization.cmc.backends.pbs import PBSBackend
_backend = PBSBackend()
elif _backend_name in ("cpu", "pjit"):
raise ValueError(
f"backend_name={_backend_name!r} cannot drive the sharded "
"CMC pipeline (no run_shards() method). Use "
"'multiprocessing' (default), 'pbs', or 'auto'."
)
else:
raise ValueError(
f"Unknown backend_name={_backend_name!r}. Valid options for the "
"sharded CMC pipeline: 'multiprocessing', 'pbs', 'auto', 'slurm', "
"'jax' (legacy alias for 'multiprocessing')."
)
logger.info(
"[CMC-sharded] Phase 3/5: dispatching %d shards to %s "
"(backend_name=%r from CMCConfig)",
num_shards,
type(_backend).__name__,
_backend_name,
)
raw_results = _backend.run_shards(
shards=parallel_shards,
config=config,
initial_values=initial_values,
parameter_space=_space,
prior_width_multiplier=prior_width_mult,
nlsq_uncertainties=nlsq_uncertainties_dict
if config.use_nlsq_informed_priors
else None,
nlsq_prior_width_factor=float(config.nlsq_prior_width_factor),
progress_bar=True,
)
# Convert worker result dicts → CMCResult objects for _combine_shard_posteriors().
shard_results: list[CMCResult] = [
_result_dict_to_cmc_result(r, config) for r in raw_results
]
# Pad with failed placeholders for any shards dropped by run_shards() (timeout/crash).
if len(shard_results) < num_shards:
n_missing = num_shards - len(shard_results)
logger.warning(
"[CMC-sharded] %d/%d shards failed or timed out",
n_missing,
num_shards,
)
# Workers report ``param_names = varying_physics_names`` (14 names,
# physics only). ``_space.varying_names`` may include the 2 scaling
# params, which would create a shape mismatch downstream when the
# combined CMCResult is consumed. Fall back to the physics-only list.
_fallback_names: list[str] = (
list(raw_results[0]["param_names"])
if raw_results
else list(_space.varying_physics_names)
)
for _ in range(n_missing):
shard_results.append(
_create_failed_result(_fallback_names, "shard failed or timed out")
)
# Emit an aggregate convergence summary at WARNING level so failures are
# visible in production logs without requiring --debug / log_level: DEBUG.
# Per-shard detail (R-hat, ESS, BFMI per parameter) remains at DEBUG to
# avoid flooding the log with up to 47 lines per angle.
_n_bad_shards = sum(1 for sr in shard_results if not sr.convergence_passed)
if _n_bad_shards > 0:
_first_bad = next(
(sr for sr in shard_results if not sr.convergence_passed), None
)
if _first_bad is not None:
_fb_rh = _first_bad.r_hat
_fb_ess = _first_bad.ess_bulk
_fb_bfmi = _first_bad.bfmi
# Suppress the "All-NaN slice encountered" RuntimeWarning that
# numpy emits when ``_fb_rh`` is all-NaN (single-chain shards,
# all-divergent posteriors). NaN propagation is the intended
# behaviour here — the diagnostic just reports it downstream.
with np.errstate(invalid="ignore"), warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=RuntimeWarning,
message="All-NaN slice encountered",
)
_fb_max_rhat = (
float(np.nanmax(_fb_rh))
if _fb_rh is not None and len(_fb_rh) > 0
else float("nan")
)
_fb_min_ess = (
float(np.nanmin(_fb_ess))
if _fb_ess is not None and len(_fb_ess) > 0
else float("nan")
)
_fb_min_bfmi = (
float(np.nanmin(np.asarray(_fb_bfmi, dtype=float)))
if _fb_bfmi is not None
else float("nan")
)
logger.warning(
"[CMC-sharded] %d/%d shards failed convergence. "
"First failing shard: max_r_hat=%.3f (threshold=%.2f), "
"min_ess=%.0f (threshold=%d), min_bfmi=%.3f. "
"Run with log_level: DEBUG (or --debug) to see per-shard details.",
_n_bad_shards,
len(shard_results),
_fb_max_rhat,
config.max_r_hat,
_fb_min_ess,
config.min_ess,
_fb_min_bfmi,
)
# --- Phase 4: consensus combination ---
logger.info("[CMC-sharded] Phase 4/5: combining shard posteriors (consensus)")
combined_result = _combine_shard_posteriors(
shard_results,
config,
num_shards,
base_seed,
)
# --- Bimodal detection across shards ---
# Run after combination so the combine path's Gaussian approximation
# can be checked for mode collapse. Results stored in metadata only —
# the caller gets a normal CMCResult but can inspect metadata["bimodal"].
bimodal_metadata: dict[str, Any] = {}
successful_with_samples = [
sr for sr in shard_results if sr.convergence_passed and sr.samples is not None
]
if len(successful_with_samples) >= 2:
from heterodyne.optimization.cmc.diagnostics import check_shard_bimodality
# successful_with_samples is filtered so sr.samples is never None;
# the cast narrows the comprehension type for Pyright.
shard_sample_dict: dict[int, dict[str, np.ndarray]] = {
i: cast("dict[str, np.ndarray]", sr.samples)
for i, sr in enumerate(successful_with_samples)
}
bimodal_results = check_shard_bimodality(
shard_sample_dict,
min_weight=config.bimodal_min_weight,
min_separation=config.bimodal_min_separation,
)
bimodal_params = [
p for p, rs in bimodal_results.items() if any(r.is_bimodal for r in rs)
]
bimodal_metadata["bimodal_detected"] = len(bimodal_params) > 0
bimodal_metadata["bimodal_params"] = bimodal_params
if bimodal_params:
logger.warning(
"[CMC-sharded] Bimodal posteriors detected for %d parameters: %s. "
"Gaussian consensus approximation may be inaccurate.",
len(bimodal_params),
bimodal_params,
)
# --- Phase 5: finalize ---
wall_time = time.perf_counter() - start_time
logger.info(
"[CMC-sharded] Phase 5/5: finalizing (total wall time=%.1fs)", wall_time
)
# Attach per-shard diagnostics to metadata
shard_diagnostics = [
{
"convergence_passed": r.convergence_passed,
"r_hat": r.r_hat.tolist() if r.r_hat is not None else None,
"ess_bulk": r.ess_bulk.tolist() if r.ess_bulk is not None else None,
"bfmi": r.bfmi,
"wall_time_seconds": r.wall_time_seconds,
}
for r in shard_results
]
metadata = dict(combined_result.metadata)
metadata["num_shards"] = num_shards
metadata["sharding_strategy"] = sharding_strategy
metadata["shard_seed"] = effective_seed
metadata["shard_diagnostics"] = shard_diagnostics
metadata["n_failed_shards"] = sum(
1 for r in shard_results if not r.convergence_passed
)
# BFMI summary across all shards — BFMI=0.000 on all shards is the
# fingerprint of the het_bb97531f degenerate warm-start failure mode.
# Storing mean_shard_bfmi and n_bfmi_zero enables post-hoc diagnosis
# without re-running with DEBUG logging.
_all_bfmi_flat = [b for r in shard_results if r.bfmi is not None for b in r.bfmi]
metadata["mean_shard_bfmi"] = (
float(np.nanmean(np.asarray(_all_bfmi_flat, dtype=float)))
if _all_bfmi_flat
else None
)
metadata["n_bfmi_zero"] = sum(
1
for r in shard_results
if r.bfmi is not None
and float(np.nanmin(np.asarray(r.bfmi, dtype=float))) < 0.01
)
metadata.update(bimodal_metadata)
final = CMCResult(
parameter_names=combined_result.parameter_names,
posterior_mean=combined_result.posterior_mean,
posterior_std=combined_result.posterior_std,
credible_intervals=combined_result.credible_intervals,
convergence_passed=combined_result.convergence_passed,
r_hat=combined_result.r_hat,
ess_bulk=combined_result.ess_bulk,
ess_tail=combined_result.ess_tail,
bfmi=combined_result.bfmi,
samples=combined_result.samples,
map_estimate=combined_result.map_estimate,
num_warmup=config.num_warmup,
num_samples=config.num_samples * num_shards,
num_chains=config.num_chains,
wall_time_seconds=wall_time,
metadata=metadata,
)
logger.info(
"[CMC-sharded] Complete in %.1fs, convergence: %s, failed shards: %d/%d",
wall_time,
"PASSED" if final.convergence_passed else "FAILED",
metadata["n_failed_shards"],
num_shards,
)
return final
# ---------------------------------------------------------------------------
# Shard creation
# ---------------------------------------------------------------------------
def _create_shards(
    c2_np: np.ndarray,
    sigma_np: np.ndarray | float,
    num_shards: int,
    strategy: str,
    seed: int,
) -> list[dict[str, Any]]:
    """Partition correlation data into shards for Consensus Monte Carlo.

    Validates ``strategy``, classifies ``sigma_np`` as scalar or array, and
    delegates to the strategy-specific helper.

    Two strategies are supported:

    - ``"random"``: shuffles the flat indices of *all* N² matrix elements
      (both triangles and the diagonal) with a generator seeded by ``seed``
      and cuts the permutation into ``num_shards`` near-equal groups.
      Each shard is a flat, per-element selection; this breaks temporal
      autocorrelation across shards.
    - ``"contiguous"``: partitions the time axis into ``num_shards``
      contiguous ranges and takes the diagonal sub-matrix for each range.
      Preserves the two-time structure within each shard.

    In both cases the per-shard sigma is sliced to match the shard shape,
    or passed through as a plain ``float`` when sigma is a scalar.

    Args:
        c2_np: Full two-time correlation matrix, shape (N, N).
        sigma_np: Uncertainty array of the same shape as ``c2_np``, or any
            zero-dimensional value (Python ``float``/``int``, NumPy scalar,
            0-d array). A scalar sigma is handed to every shard as a float.
        num_shards: Number of partitions K.
        strategy: ``"random"`` or ``"contiguous"``.
        seed: Integer seed for the random shard assignment. Only the
            ``"random"`` strategy uses it.

    Returns:
        List of K dicts. Every dict contains:

        - ``"c2_shard"``: for ``"random"`` a 1-D NumPy array of the selected
          elements; for ``"contiguous"`` a JAX array holding the square
          diagonal block.
        - ``"sigma_shard"``: matching uncertainty array, or a float when
          sigma is scalar.
        - ``"indices"``: 1-D int64 NumPy array of flat (row-major) matrix
          indices assigned to this shard.

        The ``"random"`` strategy additionally provides ``"t1_idx"`` and
        ``"t2_idx"`` (row / column index of each selected element); the
        ``"contiguous"`` strategy additionally provides ``"t_indices"``
        (the time indices spanned by the block).

    Raises:
        ValueError: If ``strategy`` is not ``"random"`` or ``"contiguous"``.
    """
    if strategy not in {"random", "contiguous"}:
        raise ValueError(
            f"sharding_strategy must be 'random' or 'contiguous', got '{strategy}'"
        )

    n = c2_np.shape[0]

    # np.ndim() is 0 for every scalar flavour (float, int, np.float32,
    # 0-d arrays). The former ``isinstance(sigma_np, float)`` test missed
    # ints and non-float64 NumPy scalars, which then reached the array
    # branch of the helpers and failed on 2-D indexing of a 0-d value.
    sigma_is_scalar = np.ndim(sigma_np) == 0

    if strategy == "random":
        return _create_shards_random(
            c2_np, sigma_np, sigma_is_scalar, num_shards, seed, n
        )

    # contiguous: diagonal blocks along the time axis
    return _create_shards_contiguous(c2_np, sigma_np, sigma_is_scalar, num_shards, n)
def _create_shards_random(
    c2_np: np.ndarray,
    sigma_np: np.ndarray | float,
    sigma_is_scalar: bool,
    num_shards: int,
    seed: int,
    n: int,
) -> list[dict[str, Any]]:
    """Split every matrix element into ``num_shards`` random flat shards.

    All ``n * n`` flat indices are shuffled with a generator seeded by
    ``seed`` and cut into ``num_shards`` near-equal chunks. Each shard holds
    the selected c2 values as a 1-D array together with the (row, col) time
    indices of every element, so the consumer can evaluate the model
    element-wise instead of on a dense grid.

    Design note carried over from the original docstring: the worker builds
    a ShardGrid from these indices and calls ``compute_c2_elementwise``,
    avoiding the O(N²) meshgrid allocation. An earlier zero-padded
    (shard_n, shard_n) sub-matrix representation was dropped because
    (1) shape metadata was lost during shared-memory serialisation, giving
    the worker a flat (1002001,) array instead of (1001, 1001), and
    (2) ``compute_c2_heterodyne`` returned (N, N) against a mis-shaped
    observation array → BroadcastError.

    Args:
        c2_np: Correlation matrix indexed as ``c2_np[row, col]``.
        sigma_np: Uncertainty array indexed like ``c2_np``, or a scalar.
        sigma_is_scalar: When true, ``sigma_np`` is converted with
            ``float()`` instead of being indexed.
        num_shards: Number of chunks to cut the permutation into.
        seed: Seed for ``np.random.default_rng``.
        n: Matrix side length used to decode flat indices.

    Returns:
        One dict per shard with keys ``"c2_shard"``, ``"sigma_shard"``,
        ``"indices"``, ``"t1_idx"`` and ``"t2_idx"``.
    """
    # Every one of the N² pairs takes part — not just the upper triangle.
    permuted = np.arange(n * n, dtype=np.int64)
    np.random.default_rng(seed).shuffle(permuted)
    chunks = np.array_split(permuted, num_shards)

    # Resolve the array view of sigma once; None marks the scalar case.
    sigma_grid = None if sigma_is_scalar else np.asarray(sigma_np)

    def _build(flat_idx: np.ndarray) -> dict[str, Any]:
        # Decode row-major flat indices into (row, col) time indices.
        t1, t2 = np.divmod(flat_idx, n)
        sigma_sel: np.ndarray | float
        if sigma_grid is None:
            sigma_sel = float(sigma_np)  # type: ignore[arg-type]
        else:
            sigma_sel = sigma_grid[t1, t2]
        return {
            "c2_shard": c2_np[t1, t2],  # shape (n_elements,)
            "sigma_shard": sigma_sel,
            "indices": flat_idx,
            "t1_idx": t1.astype(np.int64),  # row time indices
            "t2_idx": t2.astype(np.int64),  # col time indices
        }

    return [_build(chunk) for chunk in chunks]
def _create_shards_contiguous(
    c2_np: np.ndarray,
    sigma_np: np.ndarray | float,
    sigma_is_scalar: bool,
    num_shards: int,
    n: int,
) -> list[dict[str, Any]]:
    """Cut the matrix into ``num_shards`` diagonal blocks along the time axis.

    Block edges come from ``np.linspace(0, n, num_shards + 1)`` truncated to
    integers; shard ``i`` covers rows and columns ``edges[i]:edges[i + 1]``.

    Args:
        c2_np: Correlation matrix; the square block on the diagonal is
            sliced out for every shard.
        sigma_np: Uncertainty array sliced like ``c2_np``, or a scalar.
        sigma_is_scalar: When true, ``sigma_np`` is converted with
            ``float()`` instead of being sliced.
        num_shards: Number of diagonal blocks.
        n: Matrix side length; used for the edges and for flat indices.

    Returns:
        One dict per block with keys ``"c2_shard"`` (JAX array of the
        block), ``"sigma_shard"``, ``"indices"`` (row-major flat indices of
        the block's elements, int64) and ``"t_indices"`` (time indices
        spanned by the block, int64).
    """
    edges = np.linspace(0, n, num_shards + 1, dtype=int)
    blocks: list[dict[str, Any]] = []

    # Walk consecutive edge pairs: (edges[0], edges[1]), (edges[1], edges[2]), …
    for lo, hi in zip(edges[:-1].tolist(), edges[1:].tolist()):
        window = slice(lo, hi)

        sigma_block: np.ndarray | float = (
            float(sigma_np)  # type: ignore[arg-type]
            if sigma_is_scalar
            else np.asarray(sigma_np)[window, window]
        )

        # Flat index of element (r, c) is r * n + c; broadcasting a column
        # vector against a row vector enumerates the block in row-major order.
        axis = np.arange(lo, hi)
        flat = (axis[:, None] * n + axis[None, :]).ravel().astype(np.int64)

        blocks.append(
            {
                "c2_shard": jnp.asarray(c2_np[window, window]),
                "sigma_shard": sigma_block,
                "indices": flat,
                "t_indices": np.arange(lo, hi, dtype=np.int64),
            }
        )

    return blocks
# ---------------------------------------------------------------------------
# Posterior combination
# ---------------------------------------------------------------------------
def _combine_shard_posteriors(
shard_results: list[CMCResult],
config: CMCConfig,
num_shards: int,
base_seed: int,
) -> CMCResult:
"""Combine per-shard posteriors via inverse-variance weighted consensus.
Implements the Consensus Monte Carlo estimator (Scott et al., 2016):
.. math::
\\mu^* = \\left(\\sum_k \\Sigma_k^{-1}\\right)^{-1}
\\sum_k \\Sigma_k^{-1} \\mu_k
where :math:`\\mu_k` and :math:`\\Sigma_k` are the per-shard posterior
mean and (diagonal) variance. The combined variance is:
.. math::
\\Sigma^* = \\left(\\sum_k \\Sigma_k^{-1}\\right)^{-1}
The worst-case R-hat across shards is used as the combined convergence
diagnostic. Combined ESS is the sum of per-shard ESS values (an
approximation; true combined ESS would require cross-shard
autocorrelation analysis).
Args:
shard_results: List of CMCResult objects, one per shard. All
must share the same ``parameter_names``.
config: Global CMC config for convergence thresholds.
num_shards: Number of shards (used only for logging).
base_seed: Base random seed (not used here; reserved for future
importance-resampling extension).
Returns:
CMCResult with combined posterior statistics.
Raises:
ValueError: If ``shard_results`` is empty or parameter names
are inconsistent across shards.
"""
if not shard_results:
raise ValueError("shard_results must be non-empty")
param_names = shard_results[0].parameter_names
for i, sr in enumerate(shard_results[1:], start=1):
if sr.parameter_names != param_names:
raise ValueError(
f"Shard {i} parameter_names mismatch: "
f"expected {param_names}, got {sr.parameter_names}"
)
n_params = len(param_names)
# --- Filter: exclude failed/degenerate/high-divergence shards ---
_max_div_rate = getattr(config, "max_divergence_rate", 0.10)
def _shard_has_valid_samples(sr: CMCResult) -> bool:
# Per-shard inclusion gate: keep the shard if at least one parameter
# has a finite positive std. Per-parameter masking inside the
# consensus loops then excludes the specific NaN/zero entries
# without dropping the whole shard's contribution (codex W1).
std = np.asarray(sr.posterior_std)
return bool(np.any(np.isfinite(std) & (std > 0)))
def _shard_diagnostics_unknown(sr: CMCResult) -> bool:
# r_hat all-NaN means ArviZ failed to build InferenceData (e.g. API mismatch)
# but NUTS samples were collected — accept the shard on raw-sample basis.
return sr.r_hat is not None and bool(np.all(np.isnan(sr.r_hat)))
def _shard_divergence_ok(sr: CMCResult) -> bool:
return getattr(sr, "metadata", {}).get("divergence_rate", 0.0) <= _max_div_rate
def _shard_included(sr: CMCResult) -> bool:
# Single source of truth for "did this shard contribute to consensus?".
# ``successful`` and ``_failure_mask`` MUST agree on this predicate;
# deriving the mask from a weaker check (valid-samples only) would mark
# a high-divergence or non-converged shard as healthy even though it was
# dropped from the combined posterior (Codex finding).
return bool(
_shard_has_valid_samples(sr)
and _shard_divergence_ok(sr)
and (sr.convergence_passed or _shard_diagnostics_unknown(sr))
)
# Gemini P2-d / Codex: expose per-shard failure mask so callers can identify
# WHICH shards dropped, not just how many. Mask is True at index i iff shard
# i was excluded from consensus — by the SAME predicate that builds
# ``successful``. Typed reason masks below let downstream tooling see why.
_failure_mask: np.ndarray = np.array(
[not _shard_included(sr) for sr in shard_results],
dtype=bool,
)
_no_samples_mask: np.ndarray = np.array(
[not _shard_has_valid_samples(sr) for sr in shard_results],
dtype=bool,
)
_high_divergence_mask: np.ndarray = np.array(
[
_shard_has_valid_samples(sr) and not _shard_divergence_ok(sr)
for sr in shard_results
],
dtype=bool,
)
_bad_convergence_mask: np.ndarray = np.array(
[
_shard_has_valid_samples(sr)
and _shard_divergence_ok(sr)
and not (sr.convergence_passed or _shard_diagnostics_unknown(sr))
for sr in shard_results
],
dtype=bool,
)
successful = [sr for sr in shard_results if _shard_included(sr)]
n_diag_unknown = sum(1 for sr in successful if not sr.convergence_passed)
if n_diag_unknown > 0:
logger.warning(
"_combine_shard_posteriors: %d/%d accepted shards have unknown convergence "
"(ArviZ diagnostics unavailable); combining on raw-sample basis",
n_diag_unknown,
len(successful),
)
if not successful:
_n_no_samples = sum(
1 for sr in shard_results if not _shard_has_valid_samples(sr)
)
_n_high_div = sum(
1
for sr in shard_results
if _shard_has_valid_samples(sr)
and getattr(sr, "metadata", {}).get("divergence_rate", 0.0) > _max_div_rate
)
_div_rates = [
getattr(sr, "metadata", {}).get("divergence_rate")
for sr in shard_results
if _shard_has_valid_samples(sr)
]
_finite_rates = [r for r in _div_rates if r is not None]
_mean_div = float(np.mean(_finite_rates)) if _finite_rates else float("nan")
logger.error(
"_combine_shard_posteriors: all %d shards failed "
"(no_samples=%d, high_divergence[>%.0f%%]=%d, bad_convergence=%d, "
"mean_divergence_rate=%.1f%%) — "
"if divergence is high, a warm-start parameter may be at a hard bound; "
"returning degenerate result",
len(shard_results),
_n_no_samples,
_max_div_rate * 100,
_n_high_div,
len(shard_results) - _n_no_samples - _n_high_div,
_mean_div * 100,
)
return CMCResult(
parameter_names=param_names,
posterior_mean=np.zeros(n_params),
posterior_std=np.full(n_params, np.nan),
credible_intervals={},
convergence_passed=False,
r_hat=np.full(n_params, np.nan),
ess_bulk=np.full(n_params, np.nan),
ess_tail=np.full(n_params, np.nan),
bfmi=None,
samples=None,
map_estimate=None,
num_warmup=shard_results[0].num_warmup,
num_samples=shard_results[0].num_samples,
num_chains=shard_results[0].num_chains,
wall_time_seconds=None,
metadata={
"all_shards_failed": True,
"n_total_shards": num_shards,
# P2-d: mask length tracks ``len(shard_results)`` (the
# padded list the consumer sees) so callers can iterate
# the mask in lockstep with ``shard_results`` without
# worrying about length skew. No shard was included, so the
# unified mask is all-True; the typed reason masks still
# attribute *why* each shard dropped.
"failure_mask": np.full(len(shard_results), True, dtype=bool),
"no_samples_mask": _no_samples_mask,
"high_divergence_mask": _high_divergence_mask,
"bad_convergence_mask": _bad_convergence_mask,
},
)
n_skipped = len(shard_results) - len(successful)
if n_skipped > 0:
logger.warning(
"_combine_shard_posteriors: skipping %d failed/high-divergence shards "
"(%d/%d successful remain)",
n_skipped,
len(successful),
len(shard_results),
)
# --- Heterogeneity check: IQR-based CV, robust to near-zero parameters ---
# α_ref, β, v_offset, φ₀ all default to ~0; std/|mean| diverges at zero.
# IQR / max(|median|, 1e-3) stays finite and comparable across all params.
if len(successful) >= 2:
_smeans = np.stack([sr.posterior_mean for sr in successful], axis=0)
_q75, _q25 = np.percentile(_smeans, [75, 25], axis=0)
_iqr = _q75 - _q25
_denom = np.maximum(np.abs(np.median(_smeans, axis=0)), 1e-3)
_cv_robust = _iqr / _denom
_max_cv_actual = float(np.max(_cv_robust))
_cfg_max_cv = getattr(config, "max_parameter_cv", 1.0)
if _max_cv_actual > _cfg_max_cv:
_worst = param_names[int(np.argmax(_cv_robust))]
_msg = (
f"High cross-shard heterogeneity: max IQR-CV={_max_cv_actual:.2f} "
f"(threshold {_cfg_max_cv}) on parameter {_worst!r}. "
"Consider increasing min_points_per_shard or using NLSQ warm-start."
)
if getattr(config, "heterogeneity_abort", False):
raise RuntimeError(_msg)
logger.warning("_combine_shard_posteriors: %s", _msg)
combination_method = (
getattr(config, "combination_method", "consensus_mc") or "consensus_mc"
)
_known_methods = frozenset(
{"consensus_mc", "simple_average", "robust_consensus_mc", "weighted_gaussian"}
)
if combination_method not in _known_methods:
logger.warning(
"_combine_shard_posteriors: unknown combination_method %r; "
"falling back to inverse-variance (consensus_mc).",
combination_method,
)
if combination_method == "simple_average":
# Equal-weight mean and variance across shards
combined_mean = np.mean(
np.stack([sr.posterior_mean for sr in successful], axis=0), axis=0
)
combined_var = np.mean(
np.stack([sr.posterior_std**2 for sr in successful], axis=0), axis=0
)
combined_std = np.sqrt(combined_var)
elif combination_method == "robust_consensus_mc" and len(successful) >= 2:
# Per-parameter z-score outlier detection before inverse-variance combination.
# scale = max(std, 1e-4*(|mean|+1)) keeps near-zero params (α_ref, β,
# v_offset, φ₀ ≈ 0) from producing infinite z-scores.
_smeans = np.stack([sr.posterior_mean for sr in successful], axis=0) # (K,P)
_center = np.mean(_smeans, axis=0)
_scale = np.maximum(
np.std(_smeans, axis=0),
1e-4 * (np.abs(_center) + 1.0),
)
_z_max = np.max(np.abs(_smeans - _center) / _scale, axis=1) # (K,)
_inlier = _z_max <= 3.0
_n_excl = int(np.sum(~_inlier))
if _n_excl > 0:
logger.warning(
"_combine_shard_posteriors: robust_consensus_mc excluded "
"%d/%d outlier shards (z-score > 3)",
_n_excl,
len(successful),
)
_pool = [
sr for sr, keep in zip(successful, _inlier.tolist(), strict=False) if keep
] or successful
weight_sum = np.zeros(n_params)
weighted_mean_sum = np.zeros(n_params)
for sr in _pool:
# Codex W1: per-parameter mask. Non-finite / non-positive std
# in one parameter slot must NOT taint the other parameters'
# consensus, and must NOT trigger 1/0 RuntimeWarnings. The
# nested ``np.where`` shields the divide from evaluating the
# true branch on invalid entries (outer where already drops
# them from the sum but numpy still evaluates both branches).
std_k = np.asarray(sr.posterior_std)
valid = np.isfinite(std_k) & (std_k > 0)
safe_std = np.where(valid, std_k, 1.0)
w_k = np.where(valid, 1.0 / (safe_std**2), 0.0)
weight_sum += w_k
weighted_mean_sum += w_k * np.where(valid, sr.posterior_mean, 0.0)
combined_mean = weighted_mean_sum / np.where(weight_sum > 0, weight_sum, 1.0)
combined_var = np.where(
weight_sum > 0, 1.0 / np.where(weight_sum > 0, weight_sum, 1.0), np.nan
)
combined_std = np.sqrt(combined_var)
else:
# Default: inverse-variance weighting (consensus_mc / fallback)
# Weight_k = 1 / Var_k (per-parameter, diagonal approximation)
weight_sum = np.zeros(n_params)
weighted_mean_sum = np.zeros(n_params)
for sr in successful:
# Codex W1: per-parameter mask (see robust_consensus_mc branch).
std_k = np.asarray(sr.posterior_std)
valid = np.isfinite(std_k) & (std_k > 0)
safe_std = np.where(valid, std_k, 1.0)
w_k = np.where(valid, 1.0 / (safe_std**2), 0.0)
weight_sum += w_k
weighted_mean_sum += w_k * np.where(valid, sr.posterior_mean, 0.0)
combined_mean = weighted_mean_sum / np.where(weight_sum > 0, weight_sum, 1.0)
combined_var = np.where(
weight_sum > 0, 1.0 / np.where(weight_sum > 0, weight_sum, 1.0), np.nan
)
combined_std = np.sqrt(combined_var)
# --- Worst-case R-hat (conservative) ---
r_hat_arrays = [sr.r_hat for sr in successful if sr.r_hat is not None]
combined_r_hat = (
np.nanmax(np.stack(r_hat_arrays, axis=0), axis=0)
if r_hat_arrays
else np.full(n_params, np.nan)
)
# --- Summed ESS (approximate) ---
ess_bulk_arrays = [sr.ess_bulk for sr in successful if sr.ess_bulk is not None]
combined_ess_bulk = (
np.nansum(np.stack(ess_bulk_arrays, axis=0), axis=0)
if ess_bulk_arrays
else np.full(n_params, np.nan)
)
ess_tail_arrays = [sr.ess_tail for sr in successful if sr.ess_tail is not None]
combined_ess_tail = (
np.nansum(np.stack(ess_tail_arrays, axis=0), axis=0)
if ess_tail_arrays
else np.full(n_params, np.nan)
)
# --- BFMI: minimum across all shards and all chains ---
all_bfmi_values: list[float] = []
for sr in successful:
if sr.bfmi is not None:
all_bfmi_values.extend(sr.bfmi)
combined_bfmi = all_bfmi_values if all_bfmi_values else None
# --- Credible intervals from combined samples ---
# Pool samples across shards for each parameter
combined_samples: dict[str, np.ndarray] = {}
if all(sr.samples is not None for sr in successful):
for name in param_names:
arrays = [
np.asarray(sr.samples[name]) # type: ignore[index]
for sr in successful
if sr.samples is not None and name in sr.samples
]
if arrays:
combined_samples[name] = np.concatenate(arrays, axis=0)
credible_intervals: dict[str, dict[str, float]] = {}
for name in param_names:
if name in combined_samples:
s = combined_samples[name]
z95 = float(np.percentile(s, 97.5))
l95 = float(np.percentile(s, 2.5))
z89 = float(np.percentile(s, 94.5))
l89 = float(np.percentile(s, 5.5))
credible_intervals[name] = {
"lower_95": l95,
"upper_95": z95,
"lower_89": l89,
"upper_89": z89,
}
# --- MAP estimate ---
map_estimate = combined_mean.copy()
# --- Convergence gate ---
# Only evaluate diagnostics over the successful shards (failed shards are
# already excluded; any skipped shard is reflected in n_skipped above).
r_hat_finite = combined_r_hat[~np.isnan(combined_r_hat)]
ess_finite = combined_ess_bulk[~np.isnan(combined_ess_bulk)]
# Gate 1: per-shard R-hat / ESS on the survivors.
_diagnostics_passed = bool(
len(successful) > 0
and len(r_hat_finite) > 0
and np.all(r_hat_finite < config.max_r_hat)
and len(ess_finite) > 0
and np.all(ess_finite > config.min_ess)
)
# Gate 2 (Codex C1 / heterodyne min_success_rate): the *combined* result
# cannot be declared "converged" unless enough shards actually contributed.
# Previously, a single survivor with clean R-hat could mark the run passed,
# even with 46/47 shards timed out.
_success_rate = len(successful) / num_shards if num_shards > 0 else 0.0
_rate_passed = _success_rate >= float(getattr(config, "min_success_rate", 0.90))
convergence_passed = bool(_diagnostics_passed and _rate_passed)
if _diagnostics_passed and not _rate_passed:
logger.warning(
"_combine_shard_posteriors: diagnostics passed but only %d/%d "
"shards succeeded (rate=%.2f < min_success_rate=%.2f) — "
"marking combined result NOT converged.",
len(successful),
num_shards,
_success_rate,
float(getattr(config, "min_success_rate", 0.90)),
)
# BFMI is advisory for combined result (homodyne parity).
# Low combined BFMI is a useful warning but not a hard gate — the
# combined R-hat and ESS from pooled shards are the authoritative signal.
if combined_bfmi is not None:
_combined_min_bfmi = float(np.nanmin(np.asarray(combined_bfmi, dtype=float)))
if _combined_min_bfmi < config.min_bfmi:
logger.warning(
"_combine_shard_posteriors: low combined BFMI=%.3f < %.2f "
"(advisory; does not override R-hat/ESS convergence gate). "
"Consider reparameterization or wider parameter bounds.",
_combined_min_bfmi,
config.min_bfmi,
)
logger.info(
"[CMC-sharded] Consensus combination: %d/%d shards converged "
"(%d skipped/failed), worst_rhat=%.3f, combined_ess_min=%.0f",
len(successful),
num_shards,
n_skipped,
float(np.nanmax(combined_r_hat)) if combined_r_hat.size > 0 else float("nan"),
float(np.nanmin(combined_ess_bulk))
if combined_ess_bulk.size > 0
else float("nan"),
)
return CMCResult(
parameter_names=param_names,
posterior_mean=combined_mean,
posterior_std=combined_std,
credible_intervals=credible_intervals,
convergence_passed=convergence_passed,
r_hat=combined_r_hat,
ess_bulk=combined_ess_bulk,
ess_tail=combined_ess_tail,
bfmi=combined_bfmi,
samples=combined_samples if combined_samples else None,
map_estimate=map_estimate,
num_warmup=shard_results[0].num_warmup,
num_samples=shard_results[0].num_samples,
num_chains=shard_results[0].num_chains,
wall_time_seconds=None, # caller fills this in
metadata={
"n_total_shards": num_shards,
"n_successful_shards": len(successful),
"success_rate": _success_rate,
"diagnostics_passed": _diagnostics_passed,
"rate_passed": _rate_passed,
# P2-d / Codex: per-shard bool mask (True = excluded from
# consensus). Length = len(shard_results); entry i == True iff
# shard i failed the SAME inclusion predicate as ``successful``
# (no valid samples, OR high divergence, OR failed convergence).
"failure_mask": _failure_mask,
"no_samples_mask": _no_samples_mask,
"high_divergence_mask": _high_divergence_mask,
"bad_convergence_mask": _bad_convergence_mask,
},
)
# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------
def _validate_cmc_inputs(
    c2_data: np.ndarray | jnp.ndarray,
    sigma: np.ndarray | float | None,
    space: Any,
) -> None:
    """Validate inputs before starting any CMC analysis.

    Checks performed, in order (the first failing check raises):

    1. ``c2_data`` is 2-D, square (the heterodyne two-time matrix is
       always N x N) and non-empty.
    2. ``c2_data`` does not contain NaN or Inf.
    3. ``c2_data`` is approximately symmetric:
       ``max |c2 - c2.T| / max |c2| <= 1e-3`` (skipped for an all-zero
       matrix, where the ratio is undefined).
    4. If ``sigma`` is provided and is not a Python ``float``, it is
       strictly positive and has no NaN values. A plain ``float`` sigma
       is not inspected.
    5. The parameter space has at least one varying parameter.

    Args:
        c2_data: Observed correlation matrix.
        sigma: Measurement uncertainty, or ``None``.
        space: Object exposing a non-empty ``varying_names`` sequence.

    Raises:
        ValueError: On any validation failure.
    """
    c2_np = np.asarray(c2_data)

    # 1. Shape
    if c2_np.ndim != 2:
        raise ValueError(
            f"c2_data must be 2-D for CMC analysis, got {c2_np.ndim}-D "
            f"with shape {c2_np.shape}"
        )
    if c2_np.shape[0] != c2_np.shape[1]:
        raise ValueError(
            f"c2_data must be square (heterodyne two-time matrix), "
            f"got shape {c2_np.shape}"
        )
    # A 0x0 matrix is 2-D and square, but the max-reductions below would fail
    # on it with NumPy's "zero-size array to reduction operation" message.
    # Report the actual problem instead.
    if c2_np.size == 0:
        raise ValueError(
            f"c2_data is empty (shape {c2_np.shape}); "
            "CMC needs at least one data point"
        )

    # 2. NaN / Inf
    n_nan = int(np.sum(np.isnan(c2_np)))
    if n_nan > 0:
        raise ValueError(f"c2_data contains {n_nan} NaN values; clean data before CMC")
    n_inf = int(np.sum(np.isinf(c2_np)))
    if n_inf > 0:
        raise ValueError(f"c2_data contains {n_inf} Inf values; clean data before CMC")

    # 3. Approximate symmetry (relative to the largest magnitude entry)
    max_abs = float(np.max(np.abs(c2_np)))
    if max_abs > 0:
        asymmetry = float(np.max(np.abs(c2_np - c2_np.T))) / max_abs
        if asymmetry > 1e-3:
            raise ValueError(
                f"c2_data is not approximately symmetric: "
                f"max |c2 - c2.T| / max |c2| = {asymmetry:.4e} > 1e-3. "
                "The heterodyne two-time matrix must be symmetric."
            )

    # 4. Sigma
    if sigma is not None and not isinstance(sigma, float):
        sigma_np = np.asarray(sigma)
        # NaN compares False against 0, so it slips past the positivity test
        # and is caught by the dedicated NaN count below.
        if np.any(sigma_np <= 0):
            raise ValueError("sigma array must be strictly positive everywhere")
        n_nan_sigma = int(np.sum(np.isnan(sigma_np)))
        if n_nan_sigma > 0:
            raise ValueError(f"sigma contains {n_nan_sigma} NaN values")

    # 5. Parameter space
    if not hasattr(space, "varying_names") or len(space.varying_names) == 0:
        raise ValueError(
            "Parameter space has no varying parameters; "
            "at least one parameter must be free for CMC"
        )

    logger.debug(
        "[CMC] Input validation passed: shape=%s, n_varying=%d",
        c2_np.shape,
        len(space.varying_names),
    )
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
def _make_shard_config(config: CMCConfig, seed: int) -> CMCConfig:
    """Clone ``config`` for a single shard, overriding only its seed.

    Every other field is carried over unchanged, so a shard samples with the
    same warmup/sample counts and NUTS hyper-parameters as the full-data
    configuration.

    Args:
        config: Original CMCConfig.
        seed: Integer seed to use for this shard's sampling run.

    Returns:
        A new CMCConfig instance whose ``seed`` field is ``seed``.
    """
    from dataclasses import replace as _dc_replace

    shard_config = _dc_replace(config, seed=seed)
    return shard_config
def _estimate_n_workers() -> int:
    """Estimate the number of parallel workers the MultiprocessingBackend will use.

    The estimate is half of the logical CPU count minus one, never less
    than one. When the CPU count cannot be determined, four logical CPUs
    are assumed.
    """
    from multiprocessing import cpu_count

    try:
        n_logical = cpu_count()
    except NotImplementedError:
        n_logical = 4
    else:
        # Treat a falsy count as a single CPU.
        n_logical = n_logical or 1

    estimate = n_logical // 2 - 1
    return max(1, estimate)
#: Conservative resident-memory baseline for one spawned CMC worker (bytes).
#: Each worker is a fresh ``spawn`` process that imports JAX + XLA and builds
#: its own compilation cache; for the ~10K-point shards the auto-sharder
#: produces, this fixed cost dominates the per-shard data, so the worker pool
#: must be bounded by RAM, not just CPU count.
#: ``_memory_aware_worker_cap`` adds this to the largest estimated shard peak
#: to obtain the per-worker memory cost.
_WORKER_BASELINE_BYTES: int = 2_000_000_000
#: Fraction of *available* RAM the concurrent worker pool may collectively use.
#: ``_memory_aware_worker_cap`` multiplies the available-memory reading by this
#: value to obtain its byte budget.
_WORKER_MEMORY_FRACTION: float = 0.8
#: Autodiff/temporary overhead multiplier on a shard's device arrays. NUTS
#: reverse-mode through the element-wise kernel holds several data-length
#: intermediates (value, grad, cumsum/gather temporaries) per leapfrog.
#: Applied by ``_estimate_shard_peak_bytes`` together with the chain count.
_SHARD_AD_OVERHEAD: int = 8
def _available_memory_bytes() -> int | None:
    """Available system memory in bytes, or ``None`` if it can't be determined.

    ``psutil`` is used when it can be imported. Otherwise the value is
    derived from ``os.sysconf`` as available physical pages times page size.

    Returns:
        A byte count, or ``None`` when ``psutil`` is missing and the
        ``sysconf`` fallback is unavailable, fails, or reports an undefined
        value.
    """
    try:
        import psutil

        return int(psutil.virtual_memory().available)
    except ImportError:
        pass

    try:  # Fallback: available pages * page size.
        avail_pages = os.sysconf("SC_AVPHYS_PAGES")
        page_size = os.sysconf("SC_PAGE_SIZE")
    except (ValueError, OSError, AttributeError):
        return None

    # os.sysconf reports -1 for a configuration value that is not defined.
    # Multiplying that through would hand the caller a meaningless (negative,
    # or tiny positive) byte count instead of the documented ``None``.
    if avail_pages < 0 or page_size < 0:
        return None
    return int(avail_pages * page_size)
def _estimate_shard_peak_bytes(payload: dict[str, Any], num_chains: int) -> int:
    """Rough peak working set of one shard's NUTS run (excludes worker baseline).

    Adds up the byte sizes of the shard arrays found in ``payload`` (data,
    the per-point index arrays and the time grid), then scales the total by
    the autodiff overhead factor and by the chain count (at least one).
    Missing or ``None`` entries contribute nothing.
    """
    array_keys = ("data", "i1_indices", "i2_indices", "time_grid", "phi_indices")
    candidates = (payload.get(key) for key in array_keys)
    raw_bytes = sum(
        int(np.asarray(candidate).nbytes)
        for candidate in candidates
        if candidate is not None
    )
    chain_factor = max(1, num_chains)
    return raw_bytes * _SHARD_AD_OVERHEAD * chain_factor
def _memory_aware_worker_cap(
    payloads: list[dict[str, Any]], cpu_workers: int, num_chains: int
) -> int:
    """Cap worker count so concurrent shard NUTS runs fit in available RAM.

    The per-worker cost is the fixed worker baseline plus the largest
    estimated shard peak; the budget is the configured fraction of the
    available memory. The result is ``min(cpu_workers, budget // cost)``
    with the memory-derived count floored at 1 so the run always makes
    progress (capping to 1 routes ``_run_joint_shards`` to the in-process
    sequential path). When available memory is unknown, or there are no
    payloads, ``cpu_workers`` is returned unchanged. A reduction is logged
    at INFO level.
    """
    avail_bytes = _available_memory_bytes()
    if avail_bytes is None or not payloads:
        return cpu_workers

    shard_peaks = [
        _estimate_shard_peak_bytes(payload, num_chains) for payload in payloads
    ]
    bytes_per_worker = _WORKER_BASELINE_BYTES + max(shard_peaks)
    ram_budget = int(avail_bytes * _WORKER_MEMORY_FRACTION)
    ram_limited = max(1, ram_budget // bytes_per_worker)
    worker_cap = int(min(cpu_workers, ram_limited))

    if worker_cap >= cpu_workers:
        return worker_cap

    logger.info(
        "[CMC joint] memory-capping workers %d -> %d "
        "(avail=%.1fGB, ~%.2fGB/worker incl. baseline, %d shards)",
        cpu_workers,
        worker_cap,
        avail_bytes / 1e9,
        bytes_per_worker / 1e9,
        len(payloads),
    )
    return worker_cap
def _fmt_time(secs: float) -> str:
    """Render a duration in seconds as a short human-readable string.

    Durations under a minute are shown as whole seconds (``"42s"``), under
    an hour as minutes with one decimal (``"1.5min"``), and anything else
    as hours with one decimal (``"2.0h"``).
    """
    if secs < 60:
        return f"{secs:.0f}s"
    if secs < 3600:
        minutes = secs / 60
        return f"{minutes:.1f}min"
    hours = secs / 3600
    return f"{hours:.1f}h"
def _log_runtime_estimate(
    log: Any,
    n_shards: int,
    n_chains: int,
    n_warmup: int,
    n_samples: int,
    avg_points_per_shard: int,
    n_workers: int | None = None,
) -> float:
    """Log a rough CMC runtime estimate and return it in seconds.

    The per-shard cost is modelled as a JIT overhead (45 s plus 20 s per
    10,000 points) plus ``n_chains * (n_warmup + n_samples)`` iterations at
    0.2 s each plus 0.3 s per 100,000 points. Shards are assumed to run in
    ``ceil(n_shards / n_workers)`` sequential batches.

    Args:
        log: Logger-like object; only its ``info`` method is used.
        n_shards: Number of shards to run.
        n_chains: Chains per shard.
        n_warmup: Warmup iterations per chain.
        n_samples: Sampling iterations per chain.
        avg_points_per_shard: Average number of data points per shard.
        n_workers: Parallel worker count. ``None`` uses
            ``_estimate_n_workers()``; values below 1 are treated as 1.

    Returns:
        Estimated total wall time in seconds.
    """
    if n_workers is None:
        n_workers = _estimate_n_workers()
    # This helper is purely informational: a zero or negative worker count
    # must not abort the run with ZeroDivisionError in the batch computation.
    n_workers = max(1, n_workers)

    jit_overhead = 45 + (avg_points_per_shard / 10_000) * 20
    iters = n_chains * (n_warmup + n_samples)
    secs_per_iter = 0.2 + (avg_points_per_shard / 100_000) * 0.3
    total_per_shard = jit_overhead + iters * secs_per_iter
    # Ceiling division: the last batch may be only partially filled.
    batches = (n_shards + n_workers - 1) // n_workers
    total = batches * total_per_shard
    log.info(
        "Runtime estimate: %s total (%d shards / %d workers, ~%s/shard)",
        _fmt_time(total),
        n_shards,
        n_workers,
        _fmt_time(total_per_shard),
    )
    return total
def _result_dict_to_cmc_result(
    result_dict: dict[str, Any],
    config: CMCConfig,
) -> CMCResult:
    """Convert a _run_shard_worker result dict to a CMCResult.

    Workers return raw sample dicts; this helper computes ArviZ diagnostics
    (R-hat, ESS, BFMI) and constructs the full CMCResult expected by
    _combine_shard_posteriors().

    Args:
        result_dict: Worker output. A falsy or missing ``"success"`` selects
            the failure path, which reads only ``"param_names"`` and
            ``"error"``. On success ``"samples"``, ``"param_names"``,
            ``"n_chains"`` and ``"n_samples"`` are required, while
            ``"extra_fields"``, ``"duration"`` and ``"stats"`` are optional.
        config: Supplies the ``max_r_hat`` / ``min_ess`` convergence
            thresholds, the advisory ``min_bfmi`` threshold, and the
            ``num_warmup`` fallback when the worker stats omit ``n_warmup``.

    Returns:
        The shard's CMCResult. ``convergence_passed`` is decided by R-hat and
        ESS only; BFMI is logged but never gates convergence.

    Raises:
        ValueError: If any sample array does not hold exactly
            ``n_chains * n_samples`` elements.
    """
    if not result_dict.get("success", False):
        param_names: list[str] = result_dict.get("param_names", [])
        return _create_failed_result(
            param_names, result_dict.get("error", "shard failed")
        )

    samples_np: dict[str, np.ndarray] = result_dict["samples"]
    param_names = result_dict["param_names"]
    n_chains: int = result_dict["n_chains"]
    n_samples: int = result_dict["n_samples"]
    extra_fields: dict[str, np.ndarray] = result_dict.get("extra_fields", {})
    duration: float = result_dict.get("duration", 0.0)
    stats: dict[str, Any] = result_dict.get("stats", {})
    num_divergent: int = stats.get("num_divergent", 0)
    n_warmup: int = stats.get("n_warmup", config.num_warmup)

    # Normalise to NumPy once, up front, so the size check, the ArviZ reshape
    # and the returned samples all operate on the same array objects.
    physics_samples = {k: np.asarray(v) for k, v in samples_np.items()}

    # Reshape (n_chains * n_samples,) → (n_chains, n_samples) for ArviZ.
    # Shape contract: each sample array must have exactly n_chains * n_samples
    # elements; if a worker reports stale or truncated stats (e.g. on timeout)
    # the configured fallback values won't match the actual sample count and
    # numpy's reshape will raise a cryptic "cannot reshape" message deep in
    # ArviZ. Validate explicitly so the failure mode is one log line, not a
    # ten-frame traceback (deep-RCA F5, Prevention 4).
    _expected_size = int(n_chains) * int(n_samples)
    for _k, _v in physics_samples.items():
        _actual = int(_v.size)
        if _actual != _expected_size:
            raise ValueError(
                f"Shard sample-array shape contract violated for parameter "
                f"{_k!r}: expected n_chains × n_samples = "
                f"{n_chains} × {n_samples} = {_expected_size} elements, "
                f"got {_actual}. This usually means a worker returned a "
                "partial result (timeout, abort, or signal) while the result "
                "dict's n_samples fell back to the configured value. "
                "Inspect worker logs for the affected shard."
            )

    idata: az.InferenceData | None = None
    summary: Any = None
    try:
        posterior_dict = {
            k: v.reshape(n_chains, n_samples) for k, v in physics_samples.items()
        }
        idata_kwargs: dict[str, Any] = {"posterior": posterior_dict}
        if "energy" in extra_fields:
            try:
                idata_kwargs["sample_stats"] = {
                    "energy": np.asarray(extra_fields["energy"]).reshape(
                        n_chains, n_samples
                    )
                }
            except (ValueError, AttributeError) as _energy_exc:
                # Energy only feeds the advisory BFMI; carry on without it,
                # but leave a trace instead of swallowing the error silently.
                logger.debug(
                    "Shard: energy field unusable, skipping sample_stats: %s",
                    _energy_exc,
                )
        idata = az.from_dict(idata_kwargs)
        if param_names:
            with _silence_arviz_diagnostic_warnings():
                summary = az.summary(idata, var_names=param_names, ci_prob=0.95)
    except Exception as _exc:  # noqa: BLE001
        # Best-effort: without a summary the statistics below fall back to
        # raw sample moments and the diagnostics become NaN.
        logger.warning("ArviZ summary failed for shard result: %s", _exc)

    posterior_mean, posterior_std, r_hat, ess_bulk, ess_tail = _extract_posterior_stats(
        param_names, physics_samples, summary
    )
    credible_intervals = _extract_credible_intervals(
        param_names, physics_samples, summary
    )

    bfmi: list[float] | None = None
    bfmi_compute_failed = False
    if idata is not None:
        bfmi, bfmi_compute_failed = _compute_bfmi(idata)

    r_hat_finite = r_hat[~np.isnan(r_hat)]
    ess_finite = ess_bulk[~np.isnan(ess_bulk)]
    convergence_passed = bool(
        len(r_hat_finite) > 0
        and np.all(r_hat_finite < config.max_r_hat)
        and len(ess_finite) > 0
        and np.all(ess_finite > config.min_ess)
    )

    # BFMI is advisory for per-shard results (homodyne parity).
    # Homodyne check_convergence uses R-hat + ESS as the sole hard gates.
    # Per-shard BFMI < 0.3 is expected when chains start near a parameter
    # boundary (boundary reflection drives short trajectories → low BFMI).
    # Using BFMI as a hard gate here causes 100% shard rejection on
    # boundary-adjacent warm-starts such as alpha_sample=-2.0.
    _shard_min_bfmi: float | None = None
    # The length guard keeps an empty BFMI list from crashing np.nanmin over
    # a metric that is only advisory.
    if bfmi is not None and len(bfmi) > 0 and not bfmi_compute_failed:
        _shard_min_bfmi = float(np.nanmin(np.asarray(bfmi, dtype=float)))
        if _shard_min_bfmi < config.min_bfmi:
            logger.warning(
                "Shard: low BFMI=%.3f < %.2f (advisory; not a convergence gate). "
                "Chains may be near a parameter boundary.",
                _shard_min_bfmi,
                config.min_bfmi,
            )
    if bfmi_compute_failed:
        logger.debug(
            "Shard: BFMI unavailable (az.bfmi failed); "
            "convergence determined by R-hat and ESS only"
        )

    total_iters = n_chains * n_samples
    divergence_rate = num_divergent / total_iters if total_iters > 0 else 0.0

    # Escalate to WARNING when a shard fails so failures are visible in
    # production logs without --debug. PASS shards stay at DEBUG (avoid
    # flooding 47-shard runs with INFO noise).
    _diag_log = logger.warning if not convergence_passed else logger.debug
    _diag_log(
        "Shard diagnostics: convergence=%s, max_r_hat=%.3f (threshold=%.2f), "
        "min_ess=%.0f (threshold=%d), bfmi=%s, divergence_rate=%.1f%%",
        "PASS" if convergence_passed else "FAIL",
        float(np.nanmax(r_hat)) if len(r_hat_finite) > 0 else float("nan"),
        config.max_r_hat,
        float(np.nanmin(ess_bulk)) if len(ess_finite) > 0 else float("nan"),
        config.min_ess,
        f"{_shard_min_bfmi:.3f}" if _shard_min_bfmi is not None else "N/A",
        divergence_rate * 100,
    )

    return CMCResult(
        parameter_names=param_names,
        posterior_mean=posterior_mean,
        posterior_std=posterior_std,
        credible_intervals=credible_intervals,
        convergence_passed=convergence_passed,
        r_hat=r_hat,
        ess_bulk=ess_bulk,
        ess_tail=ess_tail,
        bfmi=bfmi,
        samples=physics_samples,
        map_estimate=posterior_mean.copy(),
        num_warmup=n_warmup,
        num_samples=n_samples,
        num_chains=n_chains,
        wall_time_seconds=duration,
        metadata={
            "num_divergent": num_divergent,
            "divergence_rate": divergence_rate,
        },
    )
def _extract_posterior_stats(
    output_names: list[str],
    physics_samples: dict[str, np.ndarray],
    summary: Any,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Collect mean, std, R-hat, ESS-bulk and ESS-tail per parameter.

    With a non-empty summary table, values are read from its ``mean``,
    ``sd``, ``r_hat``, ``ess_bulk`` and ``ess_tail`` columns; a parameter
    missing from the table's index gets ``0.0`` for mean/std and NaN for
    the diagnostics. Without a usable summary, mean and std are computed
    from the raw samples (``0.0`` when a parameter has none) and all
    diagnostics are NaN.

    Args:
        output_names: Ordered list of parameter names.
        physics_samples: Mapping from parameter name to its sample array.
        summary: Summary table indexed by parameter name, or ``None``.

    Returns:
        5-tuple of NumPy arrays ``(mean, std, r_hat, ess_bulk, ess_tail)``,
        each of length ``len(output_names)``.
    """
    n_out = len(output_names)

    if summary is None or len(summary) == 0:

        def _from_samples(reducer: Any) -> np.ndarray:
            # Moment of the raw draws; 0.0 for parameters with no samples.
            return np.array(
                [
                    float(reducer(physics_samples[nm]))
                    if nm in physics_samples
                    else 0.0
                    for nm in output_names
                ]
            )

        return (
            _from_samples(np.mean),
            _from_samples(np.std),
            np.full(n_out, np.nan),
            np.full(n_out, np.nan),
            np.full(n_out, np.nan),
        )

    def _from_summary(column: str, missing: float) -> np.ndarray:
        # One summary column in output order; ``missing`` fills the gaps.
        return np.array(
            [
                float(summary.loc[nm, column]) if nm in summary.index else missing
                for nm in output_names
            ]
        )

    mean_arr = _from_summary("mean", 0.0)
    std_arr = _from_summary("sd", 0.0)
    r_hat_arr = _from_summary("r_hat", np.nan)
    ess_bulk_arr = _from_summary("ess_bulk", np.nan)
    ess_tail_arr = _from_summary("ess_tail", np.nan)
    return mean_arr, std_arr, r_hat_arr, ess_bulk_arr, ess_tail_arr
def _extract_credible_intervals(
    output_names: list[str],
    physics_samples: dict[str, np.ndarray],
    summary: Any,
) -> dict[str, dict[str, float]]:
    """Build 95 % credible intervals from a summary table or raw samples.

    For each parameter present in the summary index, the ``"eti_2.5%"`` /
    ``"eti_97.5%"`` columns are tried first, then ``"hdi_2.5%"`` /
    ``"hdi_97.5%"``. When neither pair is available (or there is no
    summary), the 2.5 / 97.5 percentiles of the parameter's samples are
    used. Parameters with neither a summary interval nor samples are
    omitted from the result.

    Args:
        output_names: Ordered list of parameter names.
        physics_samples: Mapping from parameter name to its sample array.
        summary: Summary table indexed by parameter name, or ``None``.

    Returns:
        Dict mapping parameter names to ``{"2.5%": lb, "97.5%": ub}``.
    """
    intervals: dict[str, dict[str, float]] = {}

    for nm in output_names:
        bounds: tuple[float, float] | None = None

        if summary is not None and nm in summary.index:
            for prefix in ("eti", "hdi"):
                try:
                    lower = float(summary.loc[nm, f"{prefix}_2.5%"])
                    upper = float(summary.loc[nm, f"{prefix}_97.5%"])
                except KeyError:
                    # This column pair is absent; try the next naming scheme.
                    continue
                bounds = (lower, upper)
                break

        if bounds is None and nm in physics_samples:
            draws = physics_samples[nm]
            bounds = (
                float(np.percentile(draws, 2.5)),
                float(np.percentile(draws, 97.5)),
            )

        if bounds is not None:
            intervals[nm] = {"2.5%": bounds[0], "97.5%": bounds[1]}

    return intervals
def _compute_bfmi(idata: az.InferenceData) -> tuple[list[float] | None, bool]:
    """Compute BFMI from ArviZ InferenceData, returning (bfmi, failed).

    BFMI is treated as best-effort: a failure inside ``az.bfmi`` or while
    unpacking its result is logged at WARNING level and reported through
    the ``failed`` flag instead of propagating to the caller.

    Args:
        idata: ArviZ InferenceData object with sample stats.

    Returns:
        Tuple of ``(bfmi_list, compute_failed)``. ``bfmi_list`` is a
        list of per-chain BFMI values, or ``None`` if unavailable.
        ``compute_failed`` is ``True`` when an exception was raised.
    """
    bfmi: list[float] | None = None
    bfmi_compute_failed = False
    try:
        with _silence_arviz_diagnostic_warnings():
            bfmi_result = az.bfmi(idata)
        # az.bfmi() return type varies across ArviZ versions:
        #   - xr.DataArray → .values is a numpy array attribute (non-callable)
        #   - dict → .values is a bound method (callable)
        #   - list / ndarray → iterate directly
        if isinstance(bfmi_result, dict):
            bfmi = list(bfmi_result.values())
        elif hasattr(bfmi_result, "values"):
            attr = bfmi_result.values
            materialized = attr() if callable(attr) else attr
            bfmi = list(cast("Iterable[float]", materialized))
        elif isinstance(bfmi_result, (list, np.ndarray)):
            bfmi = list(bfmi_result)
    except (TypeError, KeyError, ValueError, AttributeError) as e:
        # ValueError / AttributeError are included because the exception
        # raised for a missing energy variable or sample-stats group is
        # presumably version-dependent, like the return type — TODO confirm
        # against the ArviZ versions in use. Either way a failed BFMI must
        # be reported through the flag rather than abort the caller.
        logger.warning("Could not compute BFMI: %s", e)
        bfmi = None
        bfmi_compute_failed = True
    return bfmi, bfmi_compute_failed
def _create_failed_result(parameter_names: list[str], message: str) -> CMCResult:
    """Build the placeholder CMCResult used for a failed run.

    Args:
        parameter_names: Parameter names to attach to the result.
        message: Error text, stored under ``metadata["error"]``.

    Returns:
        CMCResult with ``convergence_passed=False``, all-zero posterior
        mean/std arrays (one entry per parameter) and no credible intervals.
    """
    zeros_len = len(parameter_names)
    failure_fields: dict[str, Any] = {
        "parameter_names": parameter_names,
        "posterior_mean": np.zeros(zeros_len),
        "posterior_std": np.zeros(zeros_len),
        "credible_intervals": {},
        "convergence_passed": False,
        "metadata": {"error": message},
    }
    return CMCResult(**failure_fields)
def _grid_indices(grid: np.ndarray, values: np.ndarray, *, axis: str) -> np.ndarray:
    """Map ``values`` onto exact positions in the sorted ``grid``, bounded.

    Guards every ``np.searchsorted`` that feeds a per-point gather in the
    pooled CMC model. Three failure modes are addressed:

    * **Out-of-range index.** ``np.searchsorted(grid, v)`` returns ``len(grid)``
      for any ``v`` above ``grid[-1]`` (floating-point round-up on the last
      lag is enough). That index is one past the last valid position and
      produces an out-of-range gather. We clip into ``[0, len(grid) - 1]`` so
      the returned indices can never exceed the grid.
    * **Off-grid value.** ``side="left"`` lands on the right neighbour, so we
      also test the left neighbour and keep whichever grid point is closer
      (nearest-neighbour snap). Floating-point noise on an otherwise-regular
      meshgrid is absorbed; a value that is genuinely between grid points
      (e.g. a multi-tau lag that does not tile ``model.t``) stays far from both
      neighbours and raises with the offending value, rather than silently
      gathering the wrong cell.
    * **NaN value.** A NaN coordinate has a NaN residual, which compares
      ``False`` against any tolerance. The acceptance test is therefore
      written as "not within tolerance", so NaN raises instead of being
      mapped to the last grid cell.

    A single-point grid maps every value to index 0 without a tolerance
    check (there is no spacing to scale the tolerance by).

    Args:
        grid: Sorted, strictly increasing 1-D time axis.
        values: Per-point time coordinates to locate on ``grid`` (any shape).
        axis: Label (``"t1"``/``"t2"``) used only in the error message.

    Returns:
        ``int32`` indices into ``grid`` with the shape of ``values``, all
        within ``[0, len(grid) - 1]``.

    Raises:
        ValueError: If the grid is empty, or any value is NaN or lies off the
            grid beyond a spacing-scaled tolerance.
    """
    grid = np.asarray(grid, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    n = int(grid.size)
    if n == 0:
        raise ValueError(f"_grid_indices: empty grid for axis {axis!r}")
    if n == 1:
        return np.zeros(values.shape, dtype=np.int32)

    right = np.clip(np.searchsorted(grid, values), 0, n - 1)
    left = np.clip(right - 1, 0, n - 1)
    use_left = np.abs(values - grid[left]) < np.abs(values - grid[right])
    idx = np.where(use_left, left, right).astype(np.int32)

    diffs = np.diff(grid)
    min_spacing = float(diffs[diffs > 0].min()) if np.any(diffs > 0) else 1.0
    tol = max(min_spacing * 1e-3, 1e-12)
    residual = np.abs(grid[idx] - values)

    # "not (residual <= tol)" rather than "residual > tol": the former is also
    # True for NaN residuals, which must be rejected.
    off_grid = ~(residual <= tol)
    if np.any(off_grid):
        # Rank NaN as the worst offender, and index through flattened views so
        # the report works for scalar and multi-dimensional ``values`` alike.
        ranked = np.where(np.isnan(residual), np.inf, residual)
        worst = int(np.argmax(ranked))
        worst_value = values.ravel()[worst]
        worst_grid = grid[idx.ravel()[worst]]
        worst_residual = residual.ravel()[worst]
        raise ValueError(
            f"_grid_indices: {axis} value {worst_value:.6g} is off the time "
            f"grid (nearest grid point {worst_grid:.6g}, residual "
            f"{worst_residual:.3g} > tol {tol:.3g}); grid spans "
            f"[{grid[0]:.6g}, {grid[-1]:.6g}] with {n} points. The pooled CMC "
            "model requires (t1, t2) to tile model.t as a regular meshgrid."
        )
    return idx
# Physics parameters whose NumPyro sample-site names match the NLSQ parameter
# names 1:1 on the joint pooled path (original space, no z-transform).
# Each entry names the (D0, alpha, D_offset) sites of one transport component
# (reference, then sample). ``_build_joint_init_values`` passes each triple's
# seed values to ``_mcmc_safe_d0_component`` for the vanishing-gradient guard.
_TRANSPORT_TRIPLES = (
    ("D0_ref", "alpha_ref", "D_offset_ref"),
    ("D0_sample", "alpha_sample", "D_offset_sample"),
)
def _mcmc_safe_d0_component(
    d0: float,
    alpha: float,
    d_offset: float,
    q: float,
    dt: float,
    time_grid: np.ndarray,
    *,
    target_g1: float = 0.5,
    g1_threshold: float = 0.1,
) -> tuple[float, float] | None:
    """Scale one transport triple if its initial D0 drives g1 → 0.

    Heterodyne adaptation of homodyne ``_compute_mcmc_safe_d0`` (sampler.py).
    Homodyne's single-component model checks one ``D0``/``alpha``/``D_offset``;
    the heterodyne c1 product carries one ``exp(-q²∫J(t)dt)`` factor PER
    transport component (reference and sample), so the same vanishing-gradient
    guard is applied independently to each triple. Pure NumPy — runs at the
    I/O boundary before any JAX sampling, so the gradient-safe floor
    convention (``jnp.where``) does not apply here.

    The decay is estimated from the trapezoid-integrated
    ``D(t) = d0 * t**alpha + d_offset`` (floored at ``1e-10``) between the
    first and third quartile of the time grid, as
    ``g1 = exp(-q² * dt * integral)``. When ``g1`` is below ``g1_threshold``,
    ``d0`` and ``d_offset`` are multiplied by the factor that would bring
    ``g1`` to ``target_g1``; the scaled ``d0`` is floored at ``1.0`` and the
    scaled ``d_offset`` at ``-1e6``.

    Args:
        d0: Initial diffusion prefactor.
        alpha: Initial time exponent.
        d_offset: Initial diffusion offset.
        q: Scattering wavevector magnitude.
        dt: Time step of the grid.
        time_grid: 1-D time axis; at least two points are required.
        target_g1: Value of ``g1`` the rescaling aims for.
        g1_threshold: ``g1`` level below which rescaling is applied.

    Returns:
        ``(new_d0, new_d_offset)`` when a scaling is warranted, else ``None``.
        ``None`` is also returned whenever an input or an intermediate
        quantity is non-finite, so a NaN is never handed out as a seed.
    """
    if not (np.isfinite(d0) and np.isfinite(alpha) and np.isfinite(d_offset)):
        return None
    if time_grid is None or len(time_grid) < 2:
        return None
    try:
        epsilon = 1e-10
        time_safe = np.asarray(time_grid, dtype=np.float64) + epsilon
        d_grid = d0 * (time_safe**alpha) + d_offset
        # Floor at a tiny positive value; this also scrubs NaN entries,
        # because NaN fails the comparison.
        d_grid = np.where(d_grid > 1e-10, d_grid, 1e-10)
        trap_avg = 0.5 * (d_grid[:-1] + d_grid[1:])
        cumsum = np.concatenate([[0.0], np.cumsum(trap_avg)])
        n = len(cumsum)
        idx_low = n // 4
        idx_high = 3 * n // 4
        integral_estimate = abs(cumsum[idx_high] - cumsum[idx_low])
        prefactor = q**2 * dt
        # A NaN prefactor (non-finite q or dt) passes a bare "<= 0" test and
        # would propagate NaN into the returned seeds; reject it explicitly.
        if not np.isfinite(prefactor) or prefactor <= 0.0:
            return None
        # inf - inf from an overflowed cumulative sum yields NaN.
        if np.isnan(integral_estimate):
            return None
        log_g1 = -prefactor * integral_estimate
        g1_estimate = np.exp(max(log_g1, -700.0))
        if g1_estimate >= g1_threshold:
            return None
        target_integral = -np.log(target_g1) / prefactor
        scale_factor = (
            target_integral / integral_estimate if integral_estimate > 0 else 0.01
        )
        new_d0 = max(d0 * scale_factor, 1.0)
        new_d_offset = max(d_offset * scale_factor, -1e6)
        # Last line of defence: never return a non-finite initial value.
        if not (np.isfinite(new_d0) and np.isfinite(new_d_offset)):
            return None
        return float(new_d0), float(new_d_offset)
    except (FloatingPointError, ValueError, OverflowError):
        return None
def _build_joint_init_values(
    *,
    effective_mode: str,
    space: ParameterSpace,
    nlsq_results: list[NLSQResult] | None,
    n_phi: int,
    per_angle_contrast: np.ndarray | None,
    per_angle_offset: np.ndarray | None,
    noise_scale: float,
    q: float,
    dt: float,
    time_grid: np.ndarray,
) -> dict[str, float]:
    """Assemble ``init_to_value`` seeds for the joint pooled model's sites.

    Homodyne parity (``priors.build_init_values_dict`` +
    ``sampler.run_nuts_sampling``): NUTS is warm-started from the NLSQ
    optimum when one is available. The returned keys are raw site names:
    the physics names verbatim, ``contrast_{i}`` / ``offset_{i}`` per angle,
    and ``sigma``.

    Seeds are produced in this order:

    1. Physics parameters (every varying name except ``contrast`` and
       ``offset``): the finite value from ``nlsq_results[0].params_dict``
       when that result reports ``success``; otherwise the registry
       ``prior_mean``, and if that is missing or non-finite,
       ``space.values`` with the registry ``default`` as last resort.
    2. All physics seeds are passed through ``clamp_params_to_interior``.
    3. Each transport triple present in the seeds is checked with
       ``_mcmc_safe_d0_component``; a returned rescaling replaces the D0 and
       D_offset seeds and is logged as a warning.
    4. For ``individual`` and ``scaled`` modes, when both per-angle arrays
       are given, ``contrast_{i}`` / ``offset_{i}`` are seeded from them,
       each clamped against the base ``contrast`` / ``offset`` names.
    5. ``sigma`` is seeded with ``noise_scale`` when it is finite and
       positive.

    Sites not listed in the returned dict are simply left unseeded.
    """
    # Local import: priors.py defines clamp_params_to_interior late in the
    # module, so a top-level import would race the partial-init circular import
    # between core <-> priors. Imported here (mirrors the local numpyro.infer
    # import pattern elsewhere in this module).
    from heterodyne.optimization.cmc.priors import clamp_params_to_interior

    sampled_physics = [
        nm for nm in space.varying_names if nm not in ("contrast", "offset")
    ]

    # Only a successful first NLSQ result qualifies as a warm start.
    nlsq_point: Any = {}
    if nlsq_results and getattr(nlsq_results[0], "success", False):
        nlsq_point = nlsq_results[0].params_dict

    raw_seeds: list[float] = []
    for nm in sampled_physics:
        if nm in nlsq_point and np.isfinite(nlsq_point[nm]):
            raw_seeds.append(float(nlsq_point[nm]))
            continue
        registry_entry = DEFAULT_REGISTRY[nm]
        fallback = registry_entry.prior_mean
        if fallback is None or not np.isfinite(fallback):
            fallback = space.values.get(nm, registry_entry.default)
        raw_seeds.append(float(fallback))

    # Shift physics seeds off the TruncatedNormal walls (homodyne ±1% margin
    # parity; heterodyne uses the registry-aware 5% interior clamp so a chain
    # never initialises at a reflecting bound where the log-prob is -inf and
    # NUTS step-size adaptation collapses).
    interior_seeds, _ = clamp_params_to_interior(
        np.array(raw_seeds, dtype=float), sampled_physics
    )
    init: dict[str, float] = {
        nm: float(val)
        for nm, val in zip(sampled_physics, interior_seeds, strict=True)
    }

    # Vanishing-gradient guard, per transport component (homodyne parity).
    for triple in _TRANSPORT_TRIPLES:
        if not all(site in init for site in triple):
            continue
        d0_site, alpha_site, offset_site = triple
        rescaled = _mcmc_safe_d0_component(
            init[d0_site],
            init[alpha_site],
            init[offset_site],
            q,
            dt,
            time_grid,
        )
        if rescaled is None:
            continue
        safe_d0, safe_offset = rescaled
        logger.warning(
            "[CMC joint] MCMC-safe init: %s=%.4g causes g1→0 "
            "(vanishing gradients); scaling to %.4g (and %s to %.4g) "
            "for NUTS exploration stability.",
            d0_site,
            init[d0_site],
            safe_d0,
            offset_site,
            safe_offset,
        )
        init[d0_site] = safe_d0
        init[offset_site] = safe_offset

    per_angle_mode = effective_mode in ("individual", "scaled")
    have_estimates = per_angle_contrast is not None and per_angle_offset is not None
    if per_angle_mode and have_estimates:
        # The sampled sites are contrast_{i}/offset_{i}, which are NOT in
        # the registry — clamp against the base "contrast"/"offset" interior
        # bounds (same 5% margin as physics) so the data-driven quantile
        # estimates never sit on the TruncatedNormal wall (-inf log-prob).
        for angle in range(n_phi):
            contrast_in, _ = clamp_params_to_interior(
                np.array([per_angle_contrast[angle]], dtype=float), ["contrast"]
            )
            offset_in, _ = clamp_params_to_interior(
                np.array([per_angle_offset[angle]], dtype=float), ["offset"]
            )
            init[f"contrast_{angle}"] = float(contrast_in[0])
            init[f"offset_{angle}"] = float(offset_in[0])

    if np.isfinite(noise_scale) and noise_scale > 0:
        init["sigma"] = float(noise_scale)

    return init
def _create_joint_init_strategy(
    initial_values: dict[str, float] | None,
    config: CMCConfig,
) -> Any:
    """Pick the NumPyro init strategy for the joint pooled path.

    Mirrors homodyne ``create_init_strategy``. A non-empty
    ``initial_values`` yields ``init_to_value`` over a copy of that dict
    (original-space site names). Otherwise the strategy named by
    ``config.init_strategy`` is used — ``init_to_median`` or
    ``init_to_sample`` — with ``init_to_median`` for any other name. The
    choice is logged at INFO level either way.
    """
    from numpyro.infer import initialization as numpyro_init

    if not initial_values:
        known_fallbacks = {
            "init_to_median": numpyro_init.init_to_median,
            "init_to_sample": numpyro_init.init_to_sample,
        }
        chosen = known_fallbacks.get(
            config.init_strategy, numpyro_init.init_to_median
        )
        logger.info(
            "[CMC joint] no initial values; falling back to %s",
            getattr(chosen, "__name__", str(chosen)),
        )
        return chosen()

    site_preview = sorted(initial_values)[:6]
    logger.info(
        "[CMC joint] init_to_value wired for %d sample sites: %s",
        len(initial_values),
        site_preview,
    )
    seed_values = dict(initial_values)
    return numpyro_init.init_to_value(values=seed_values)
def _joint_pooled_nuts_run(
    *,
    effective_mode: str,
    data: np.ndarray,
    time_grid: np.ndarray,
    q: float,
    dt: float,
    phi_unique: np.ndarray,
    phi_indices: np.ndarray,
    i1_indices: np.ndarray,
    i2_indices: np.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    fixed_contrast: np.ndarray | float | None,
    fixed_offset: np.ndarray | float | None,
    num_shards_model: int,
    config: CMCConfig,
    n_phi: int,
    rng_seed: int,
    result_num_shards: int,
    keep_samples: bool = True,
    num_warmup: int | None = None,
    num_samples: int | None = None,
    initial_values: dict[str, float] | None = None,
) -> CMCResult:
    """Run one NUTS pass over a pooled joint model and package a ``CMCResult``.

    Used by both the single-pass (K=1) and the per-shard (K>1) joint
    multi-phi paths, so posterior extraction is identical in both.

    Parameters
    ----------
    effective_mode:
        Per-angle mode handed to ``get_heterodyne_pooled_model_for_mode``.
    data, time_grid, q, dt, phi_unique, phi_indices, i1_indices, i2_indices,
    noise_scale, space, fixed_contrast, fixed_offset:
        Forwarded to the model builder (arrays via ``jnp.asarray``, ``q`` and
        ``dt`` via ``float``). Callers are expected to pass the GLOBAL angle
        set in ``phi_unique`` / ``phi_indices`` so per-shard parameter
        vectors line up for consensus.
    num_shards_model:
        Forwarded to the model builder as ``num_shards`` (documented upstream
        as the sqrt(K) Consensus MC prior-tempering factor).
    config:
        Supplies ``target_accept_prob``, ``dense_mass``, ``num_chains``,
        ``chain_method``, the default iteration counts and, optionally,
        ``max_divergence_rate``.
    n_phi:
        Number of angles; fixes the number of ``contrast_{i}`` /
        ``offset_{i}`` entries in the reported parameter vector.
    rng_seed:
        Seed for ``jax.random.PRNGKey``.
    result_num_shards:
        Stamped on ``CMCResult.num_shards``; also selects the convergence
        gate (divergence rate when > 1, zero divergences otherwise).
    keep_samples:
        When ``False`` the returned result carries ``samples=None``.
    num_warmup, num_samples:
        Overrides for the config iteration counts; ``None`` uses ``config``.
    initial_values:
        Optional site values for ``_create_joint_init_strategy``.

    Returns
    -------
    CMCResult
        Parameter order is physics names (``space.varying_names`` without
        ``contrast`` / ``offset``), then ``contrast_0..``, then
        ``offset_0..``. Sites absent from the samples keep mean/std of 0.
    """
    warmup_iters = int(num_warmup) if num_warmup is not None else config.num_warmup
    sample_iters = int(num_samples) if num_samples is not None else config.num_samples

    pooled_model = get_heterodyne_pooled_model_for_mode(
        effective_mode,
        data=jnp.asarray(data),
        t=jnp.asarray(time_grid),
        q=float(q),
        dt=float(dt),
        phi_unique=jnp.asarray(phi_unique),
        phi_indices=jnp.asarray(phi_indices),
        i1_indices=jnp.asarray(i1_indices),
        i2_indices=jnp.asarray(i2_indices),
        noise_scale=noise_scale,
        space=space,
        fixed_contrast=fixed_contrast,
        fixed_offset=fixed_offset,
        num_shards=num_shards_model,
    )

    # Wall time covers sampler construction and sampling, not model building.
    t_begin = time.perf_counter()
    strategy = _create_joint_init_strategy(initial_values, config)
    nuts_kernel = NUTS(
        pooled_model,
        target_accept_prob=config.target_accept_prob,
        dense_mass=config.dense_mass,
        init_strategy=strategy,
    )
    sampler = MCMC(
        nuts_kernel,
        num_warmup=warmup_iters,
        num_samples=sample_iters,
        num_chains=config.num_chains,
        chain_method=config.chain_method,
        progress_bar=False,
    )
    sampler.run(jax.random.PRNGKey(rng_seed))

    samples = {
        site: np.asarray(draws)
        for site, draws in sampler.get_samples(group_by_chain=True).items()
    }
    extras = sampler.get_extra_fields(group_by_chain=True)
    divergences = int(np.asarray(extras.get("diverging", [])).sum())
    wall_time = time.perf_counter() - t_begin

    # Reported parameter vector: shared physics first, then per-angle scaling.
    scaling_sites = ("contrast", "offset")
    physics_names = [nm for nm in space.varying_names if nm not in scaling_sites]
    parameter_names = [
        *physics_names,
        *(f"contrast_{k}" for k in range(n_phi)),
        *(f"offset_{k}" for k in range(n_phi)),
    ]

    n_params = len(parameter_names)
    posterior_mean = np.zeros(n_params)
    posterior_std = np.zeros(n_params)
    for slot, site in enumerate(parameter_names):
        if site not in samples:
            # Site was not sampled (e.g. fixed in this mode): leave zeros.
            continue
        pooled_draws = samples[site].reshape(-1)
        posterior_mean[slot] = float(np.nanmean(pooled_draws))
        posterior_std[slot] = float(np.nanstd(pooled_draws))

    # Per-angle blocks sit directly after the physics block.
    n_phys = len(physics_names)
    contrast_block = slice(n_phys, n_phys + n_phi)
    offset_block = slice(n_phys + n_phi, n_phys + 2 * n_phi)
    mean_contrast = posterior_mean[contrast_block].copy()
    std_contrast = posterior_std[contrast_block].copy()
    mean_offset = posterior_mean[offset_block].copy()
    std_offset = posterior_std[offset_block].copy()

    total_iters = sample_iters * config.num_chains
    divergence_rate = divergences / total_iters if total_iters > 0 else 0.0

    if result_num_shards > 1:
        # Shard feeding consensus: gate on the divergence RATE so a few
        # divergent transitions do not exclude the whole shard.
        convergence_passed = divergence_rate <= getattr(
            config, "max_divergence_rate", 0.10
        )
    else:
        # Result goes straight to the user: strict zero-divergence status.
        convergence_passed = divergences == 0
    convergence_status = "converged" if convergence_passed else "divergences"

    return CMCResult(
        parameter_names=parameter_names,
        posterior_mean=posterior_mean,
        posterior_std=posterior_std,
        credible_intervals={},
        convergence_passed=convergence_passed,
        convergence_status=convergence_status,
        samples=samples if keep_samples else None,
        num_warmup=warmup_iters,
        num_samples=sample_iters,
        num_chains=config.num_chains,
        num_shards=result_num_shards,
        divergences=divergences,
        wall_time_seconds=float(wall_time),
        mean_contrast=mean_contrast,
        std_contrast=std_contrast,
        mean_offset=mean_offset,
        std_offset=std_offset,
        per_angle_mode="individual",
        metadata={
            "joint_multi_phi": True,
            "n_phi": n_phi,
            "n_total": int(np.asarray(data).size),
            "phi_unique": np.asarray(phi_unique).tolist(),
            "num_divergent": divergences,
            "divergence_rate": divergence_rate,
        },
    )
def _adaptive_shard_iters(config: CMCConfig, shard_size: int) -> tuple[int, int]:
"""Per-shard ``(warmup, samples)`` with homodyne-style adaptive scaling.
When ``config.adaptive_sampling`` is enabled, scale iterations by
``min(1, shard_size / 10000)`` with floors ``min_warmup`` / ``min_samples``.
Small shards converge with fewer iterations, so this cuts sequential
wall-clock without shortening larger shards. Disabled → full config values.
"""
if not getattr(config, "adaptive_sampling", False):
return config.num_warmup, config.num_samples
scale = min(1.0, shard_size / 10000.0)
warmup = max(
int(getattr(config, "min_warmup", 100)), int(config.num_warmup * scale)
)
samples = max(
int(getattr(config, "min_samples", 200)), int(config.num_samples * scale)
)
return warmup, samples
def _run_joint_pooled_shard_local(payload: dict[str, Any]) -> CMCResult:
    """Execute a single shard payload in the current process.

    ``payload`` holds the keyword arguments of ``_joint_pooled_nuts_run``;
    it is unpacked verbatim and the resulting ``CMCResult`` is returned.
    Used by the sequential shard loop (single worker or parallel fallback).
    """
    shard_result = _joint_pooled_nuts_run(**payload)
    return shard_result
def _run_joint_shards(
    payloads: list[dict[str, Any]], config: CMCConfig, n_shards: int
) -> list[CMCResult]:
    """Run every per-shard pooled NUTS payload and return the shard results.

    Shards are dispatched to worker processes through
    ``run_joint_pooled_shards_parallel`` when the backend allows it and more
    than one worker is available. Any failure of the parallel dispatch is
    logged and the shards are run sequentially in-process instead.

    Parameters
    ----------
    payloads:
        One keyword-argument dict per shard for ``_joint_pooled_nuts_run``.
    config:
        Reads ``backend_name`` (default ``"auto"``), ``num_chains``,
        ``num_warmup``, ``num_samples``, ``per_shard_timeout`` and the
        optional ``progress_bar`` flag.
    n_shards:
        Number of shards; bounds the worker count and sizes the progress bar.

    Returns
    -------
    list[CMCResult]
        One result per shard.
    """
    backend = (getattr(config, "backend_name", "auto") or "auto").lower()
    cpu_workers = min(_estimate_n_workers(), n_shards)
    # Bound concurrency by available RAM: each worker runs a full NUTS pass
    # concurrently, so peak memory is roughly n_workers x per-shard. A cap of
    # 1 routes to the sequential in-process path below.
    n_workers = _memory_aware_worker_cap(payloads, cpu_workers, config.num_chains)
    parallel_ok = (
        backend in ("auto", "multiprocessing", "pjit", "jit")
        and n_workers > 1
        and n_shards > 1
    )
    if parallel_ok:
        try:
            from heterodyne.optimization.cmc.backends.multiprocessing import (
                run_joint_pooled_shards_parallel,
            )

            logger.info(
                "[CMC joint] dispatching %d shards across %d workers (parallel CMC)",
                n_shards,
                n_workers,
            )
            # Pre-flight ETA so the user sees the expected wall-clock up front
            # (the run is otherwise silent until the first shard completes).
            # The estimate is cosmetic, so it gets its OWN guard: a failure
            # here must not abort the parallel dispatch and force the much
            # slower sequential fallback.
            try:
                _sizes = [
                    int(np.asarray(d).size)
                    for p in payloads
                    if (d := p.get("data")) is not None
                ]
                if _sizes:
                    _log_runtime_estimate(
                        logger,
                        n_shards=n_shards,
                        n_chains=config.num_chains,
                        n_warmup=config.num_warmup,
                        n_samples=config.num_samples,
                        avg_points_per_shard=sum(_sizes) // len(_sizes),
                        n_workers=n_workers,
                    )
            except Exception:  # noqa: BLE001 — best-effort logging only
                logger.debug(
                    "[CMC joint] runtime estimate skipped", exc_info=True
                )
            return run_joint_pooled_shards_parallel(
                payloads,
                n_workers=n_workers,
                num_chains=config.num_chains,
                per_shard_timeout=config.per_shard_timeout,
            )
        except Exception:  # noqa: BLE001 — degrade to sequential, never crash
            logger.warning(
                "[CMC joint] parallel shard dispatch failed; running %d shards "
                "sequentially in-process",
                n_shards,
                exc_info=True,
            )

    # Sequential in-process path (single worker or parallel fallback). Emit a
    # tqdm bar plus per-shard logging so progress is visible (homodyne parity).
    from tqdm import tqdm

    results: list[CMCResult] = []
    total_div = 0
    with tqdm(
        total=n_shards,
        desc="CMC joint shards (sequential)",
        unit="shard",
        disable=not getattr(config, "progress_bar", True),
    ) as pbar:
        for si, p in enumerate(payloads):
            sr = _run_joint_pooled_shard_local(p)
            results.append(sr)
            total_div += int(sr.divergences)
            pbar.update(1)
            pbar.set_postfix(shard=si, div=total_div, ok=bool(sr.convergence_passed))
            logger.info(
                "[CMC joint] shard %d/%d complete (divergences=%d, status=%s)",
                si + 1,
                n_shards,
                int(sr.divergences),
                sr.convergence_status,
            )
    return results
[docs]
def fit_cmc_multi_phi(
    model: HeterodyneModel,
    c2_data: np.ndarray | jnp.ndarray,
    phi_angles: np.ndarray | list[float],
    config: CMCConfig | None = None,
    nlsq_results: list[NLSQResult] | None = None,
    sigma: np.ndarray | float | None = None,
) -> CMCResult:
    """Joint multi-phi CMC entry point (homodyne parity).

    Pools the stacked correlation matrices into flat per-point arrays and
    hands them to the shared pooled engine :func:`_fit_cmc_pooled`. The
    engine fits shared physics parameters plus per-angle contrast / offset,
    and either runs a single NUTS pass or shards the data and combines the
    shard posteriors by Consensus Monte Carlo.

    Steps performed here:

    1. Validate shapes and flatten ``c2_data`` (``(n_phi, N, N)``, or
       ``(N, N)`` treated as ``n_phi=1``) into ``(data, t1, t2, phi)`` arrays
       of length ``n_phi * N * N`` in angle-major, ``t1``-major order.
    2. Call :func:`prepare_mcmc_data` with ``filter_diagonal=True`` to obtain
       a :class:`PooledCMCData` container.
    3. Delegate indexing, model construction, sharding, sampling and
       combination to :func:`_fit_cmc_pooled`.

    Parameters
    ----------
    model:
        Configured :class:`HeterodyneModel`; ``model.t`` is the (N,) time
        grid both time axes of ``c2_data`` are defined on.
    c2_data:
        Experimental c2 of shape ``(n_phi, N, N)`` or ``(N, N)``.
    phi_angles:
        Detector phi angles, one per leading entry of ``c2_data``.
    config:
        :class:`CMCConfig`. ``None`` uses defaults.
    nlsq_results:
        Optional per-angle NLSQ results, forwarded unchanged to
        :func:`_fit_cmc_pooled` for warm-starting.
    sigma:
        Currently NOT used: the noise scale comes from
        :func:`prepare_mcmc_data`. A warning is logged when a value is
        supplied so the override is not dropped silently.

    Returns
    -------
    CMCResult
        Joint multi-phi result as produced by :func:`_fit_cmc_pooled`.

    Raises
    ------
    ValueError
        If ``c2_data`` is not 2-D/3-D or not square in the time axes, if
        ``phi_angles`` does not have one entry per angle, or if ``model.t``
        does not match the number of time bins.
    """
    if sigma is not None:
        # The parameter is accepted for API stability but not wired through
        # yet; tell the caller rather than ignoring their input silently.
        logger.warning(
            "fit_cmc_multi_phi: 'sigma' is not used yet; the noise scale comes "
            "from prepare_mcmc_data."
        )
    if config is None:
        config = CMCConfig()
    _validate_config_or_raise(config)

    # ---- Phase 1: pool the stacked c2 into flat (n_total,) arrays ----
    c2_np = np.asarray(c2_data, dtype=np.float64)
    if c2_np.ndim == 2:
        c2_np = c2_np[None, :, :]
    if c2_np.ndim != 3:
        raise ValueError(
            f"fit_cmc_multi_phi: c2_data must be 2-D or 3-D, got ndim={c2_np.ndim}"
        )
    phi_arr = np.asarray(phi_angles, dtype=np.float64).reshape(-1)
    n_phi_input, n_t1, n_t2 = c2_np.shape
    if n_t1 != n_t2:
        raise ValueError(
            "fit_cmc_multi_phi: c2_data must be square in the time axes, "
            f"got shape {c2_np.shape}"
        )
    if phi_arr.size != n_phi_input:
        raise ValueError(
            "fit_cmc_multi_phi: phi_angles length "
            f"{phi_arr.size} does not match c2_data n_phi={n_phi_input}"
        )
    time_grid = np.asarray(model.t, dtype=np.float64)
    if time_grid.size != n_t1:
        raise ValueError(
            f"fit_cmc_multi_phi: model.t has {time_grid.size} points but "
            f"c2_data has {n_t1} time bins; they must match."
        )

    n_grid = int(time_grid.size)
    points_per_angle = n_grid * n_grid
    data_flat = c2_np.reshape(n_phi_input * points_per_angle)
    # Row-major flattening of an (N, N) block: t1 is constant along a row
    # (repeat) and t2 cycles through the grid (tile). This yields the same
    # arrays as an "ij" meshgrid without allocating two N x N index matrices.
    t1_flat = np.tile(np.repeat(time_grid, n_grid), n_phi_input)
    t2_flat = np.tile(time_grid, n_grid * n_phi_input)
    phi_flat = np.repeat(phi_arr, points_per_angle)

    # ---- Phase 2: prepare pooled data (homodyne parity) ----
    prepared: PooledCMCData = prepare_mcmc_data(
        data_flat, t1_flat, t2_flat, phi_flat, filter_diagonal=True
    )

    # ---- Phase 3+: index -> build model -> shard -> sample -> combine ----
    return _fit_cmc_pooled(model, prepared, time_grid, config, nlsq_results)
def _fit_cmc_pooled(
    model: HeterodyneModel,
    prepared: PooledCMCData,
    time_grid: np.ndarray,
    config: CMCConfig,
    nlsq_results: list[NLSQResult] | None,
) -> CMCResult:
    """Pooled-CMC engine shared by the joint entry points.

    Pipeline: grid indexing -> per-angle mode and scaling estimates -> NUTS
    init values -> sharding decision -> NUTS (single pass or per shard) ->
    consensus combination.

    Parameters
    ----------
    model:
        Supplies ``param_manager.space``, ``q`` and ``dt``.
    prepared:
        Pooled, flattened data container (``data``, ``t1``, ``t2``,
        ``phi_unique``, ``phi_indices``, ``n_phi``, ``n_total``,
        ``noise_scale``).
    time_grid:
        The model time grid; ``prepared.t1`` / ``prepared.t2`` are mapped
        onto it by ``_grid_indices``.
    config:
        Sampling, seeding and sharding configuration.
    nlsq_results:
        Optional NLSQ results. A non-empty list counts as a warm start for
        mode selection and is passed to ``_build_joint_init_values``.

    Returns
    -------
    CMCResult
        The single-pass result, or the consensus-combined result with
        ``num_shards`` and sharding metadata filled in.
    """
    time_grid = np.asarray(time_grid, dtype=np.float64)
    n_grid = int(time_grid.size)

    # ---- Phase 3: per-point grid indices for the gather inside the model ----
    i1_indices = _grid_indices(time_grid, prepared.t1, axis="t1")
    i2_indices = _grid_indices(time_grid, prepared.t2, axis="t2")
    logger.info(
        "[CMC joint] n_phi=%d, n_total=%d, n_grid=%d, noise_scale=%.4e",
        prepared.n_phi,
        prepared.n_total,
        n_grid,
        prepared.noise_scale,
    )

    # ---- Phase 4: resolve the per-angle mode used to build the model ----
    space = model.param_manager.space
    effective_mode = config.get_effective_per_angle_mode(
        n_phi=prepared.n_phi,
        nlsq_per_angle_mode=None,  # NLSQ mode is not propagated yet
        has_nlsq_warmstart=bool(nlsq_results),
    )
    logger.info("[CMC joint] effective per-angle mode: %r", effective_mode)

    # Quantile-based contrast/offset estimate per angle: offset is the 5%
    # quantile, contrast the 5%-95% spread. Angles without data fall back to
    # contrast=1, offset=0. The estimates serve as fixed inputs for the
    # constant modes and as init seeds for the sampled per-angle sites.
    per_angle_contrast = np.zeros(prepared.n_phi)
    per_angle_offset = np.zeros(prepared.n_phi)
    for angle in range(prepared.n_phi):
        angle_vals = prepared.data[prepared.phi_indices == angle]
        if angle_vals.size:
            low_q = float(np.quantile(angle_vals, 0.05))
            high_q = float(np.quantile(angle_vals, 0.95))
            per_angle_offset[angle] = low_q
            per_angle_contrast[angle] = high_q - low_q
        else:
            per_angle_contrast[angle] = 1.0
            per_angle_offset[angle] = 0.0

    fixed_contrast_arg: np.ndarray | float | None
    fixed_offset_arg: np.ndarray | float | None
    if effective_mode == "constant":
        fixed_contrast_arg, fixed_offset_arg = per_angle_contrast, per_angle_offset
    elif effective_mode == "constant_averaged":
        fixed_contrast_arg = float(per_angle_contrast.mean())
        fixed_offset_arg = float(per_angle_offset.mean())
    else:
        fixed_contrast_arg, fixed_offset_arg = None, None

    # ---- NUTS warm-start values (always seeded, homodyne parity) ----
    # Keys are raw site names because the joint model samples in original
    # space; see _build_joint_init_values for how the values are sourced.
    initial_values = _build_joint_init_values(
        effective_mode=effective_mode,
        space=space,
        nlsq_results=nlsq_results,
        n_phi=prepared.n_phi,
        per_angle_contrast=per_angle_contrast,
        per_angle_offset=per_angle_offset,
        noise_scale=prepared.noise_scale,
        q=float(model.q),
        dt=float(model.dt),
        time_grid=time_grid,
    )
    nlsq_seeded = bool(nlsq_results and getattr(nlsq_results[0], "success", False))
    logger.info(
        "[CMC joint] init_to_value seeded for %d sites (NLSQ warm-start: %s)",
        len(initial_values),
        "yes" if nlsq_seeded else "no",
    )

    # ---- Phase 5: sharding decision (Consensus Monte Carlo) ----
    # Small data -> one NUTS pass; large data -> shard + consensus, because
    # NUTS cost grows with the number of points per leapfrog step.
    if config.seed is not None:
        base_seed = config.seed
    else:
        base_seed = secrets.randbelow(2**31)
    forced_shards = isinstance(config.num_shards, int) or isinstance(
        config.max_points_per_shard, int
    )
    single_shard_limit = 100_000  # project rule: no shards of 100K+ points
    should_shard = forced_shards or prepared.n_total > single_shard_limit

    def _nuts_kwargs(
        *,
        data: np.ndarray,
        phi_indices: np.ndarray,
        i1: np.ndarray,
        i2: np.ndarray,
        noise_scale: float,
        shard_count: int,
        rng_seed: int,
        keep_samples: bool,
        num_warmup: int | None = None,
        num_samples: int | None = None,
    ) -> dict[str, Any]:
        # One builder for both paths keeps the keyword set passed to
        # _joint_pooled_nuts_run identical for single-pass and shard runs.
        return {
            "effective_mode": effective_mode,
            "data": data,
            "time_grid": time_grid,
            "q": float(model.q),
            "dt": float(model.dt),
            "phi_unique": prepared.phi_unique,
            "phi_indices": phi_indices,
            "i1_indices": i1,
            "i2_indices": i2,
            "noise_scale": noise_scale,
            "space": space,
            "fixed_contrast": fixed_contrast_arg,
            "fixed_offset": fixed_offset_arg,
            "num_shards_model": shard_count,
            "config": config,
            "n_phi": prepared.n_phi,
            "rng_seed": rng_seed,
            "result_num_shards": shard_count,
            "keep_samples": keep_samples,
            "num_warmup": num_warmup,
            "num_samples": num_samples,
            "initial_values": initial_values,
        }

    def _run_unsharded() -> CMCResult:
        # Whole pooled dataset in one NUTS pass, samples retained.
        return _joint_pooled_nuts_run(
            **_nuts_kwargs(
                data=prepared.data,
                phi_indices=prepared.phi_indices,
                i1=i1_indices,
                i2=i2_indices,
                noise_scale=prepared.noise_scale,
                shard_count=1,
                rng_seed=base_seed,
                keep_samples=True,
            )
        )

    if not should_shard:
        logger.info(
            "[CMC joint] single-shard NUTS: n_total=%d (<= %d); chains=%d, "
            "warmup=%d, samples=%d",
            prepared.n_total,
            single_shard_limit,
            config.num_chains,
            config.num_warmup,
            config.num_samples,
        )
        return _run_unsharded()

    # ---- Multi-shard Consensus Monte Carlo ----
    # Multi-angle data is sharded angle-balanced so every shard sees all
    # angles (needed for consensus on the shared physics parameters);
    # single-angle data is sharded at random.
    explicit_shards = config.num_shards if isinstance(config.num_shards, int) else None
    if isinstance(config.max_points_per_shard, int):
        max_per_shard = config.max_points_per_shard
    else:
        max_per_shard = 10_000

    if prepared.n_phi > 1:
        if config.sharding_strategy == "stratified":
            logger.warning(
                "[CMC joint] Overriding sharding_strategy='stratified' -> "
                "'angle_balanced' for multi-angle joint CMC; stratified shards "
                "create disjoint posteriors that violate Consensus MC assumptions "
                "for shared global physics parameters."
            )
        strategy_used = "angle_balanced"
        shards = shard_pooled_angle_balanced(
            prepared,
            num_shards=explicit_shards,
            max_points_per_shard=max_per_shard,
            max_shards=500,
            seed=base_seed,
        )
    else:
        strategy_used = "random"
        shards = shard_pooled_random(
            prepared,
            num_shards=explicit_shards,
            max_points_per_shard=max_per_shard,
            max_shards=100,
            seed=base_seed,
        )

    n_shards = len(shards)
    if n_shards <= 1:
        logger.info(
            "[CMC joint] sharding collapsed to 1 shard; running single NUTS pass"
        )
        return _run_unsharded()

    logger.info(
        "[CMC joint] Consensus Monte Carlo: %d points -> %d shards (strategy=%s); "
        "NUTS per shard (chains=%d, warmup=%d, samples=%d)",
        prepared.n_total,
        n_shards,
        strategy_used,
        config.num_chains,
        config.num_warmup,
        config.num_samples,
    )

    shard_payloads: list[dict[str, Any]] = []
    for shard_no, shard in enumerate(shards):
        # Map shard points onto the GLOBAL angle index so every shard's
        # contrast_i / offset_i vector has the same length and ordering,
        # which _combine_shard_posteriors requires.
        angle_gap = np.abs(shard.phi[:, None] - prepared.phi_unique[None, :])
        global_phi_idx = np.argmin(angle_gap, axis=1).astype(np.int32)
        shard_i1 = _grid_indices(time_grid, shard.t1, axis="t1")
        shard_i2 = _grid_indices(time_grid, shard.t2, axis="t2")
        shard_warmup, shard_samples = _adaptive_shard_iters(config, shard.n_total)
        shard_payloads.append(
            _nuts_kwargs(
                data=shard.data,
                phi_indices=global_phi_idx,
                i1=shard_i1,
                i2=shard_i2,
                noise_scale=shard.noise_scale,
                shard_count=n_shards,
                rng_seed=base_seed + 1 + shard_no,  # distinct stream per shard
                keep_samples=False,
                num_warmup=shard_warmup,
                num_samples=shard_samples,
            )
        )

    shard_results = _run_joint_shards(shard_payloads, config, n_shards)
    total_divergences = sum(int(res.divergences) for res in shard_results)
    logger.info(
        "[CMC joint] all %d shards complete (%d total divergences)",
        n_shards,
        total_divergences,
    )

    combined = _combine_shard_posteriors(shard_results, config, n_shards, base_seed)
    combined.num_shards = n_shards
    combined.metadata.update(
        {
            "joint_multi_phi": True,
            "n_phi": prepared.n_phi,
            "n_total": prepared.n_total,
            "phi_unique": prepared.phi_unique.tolist(),
            "sharding_strategy": strategy_used,
            "num_shards": n_shards,
        }
    )
    logger.info(
        "[CMC joint] consensus complete: combined %d shards, status=%s",
        n_shards,
        combined.convergence_status,
    )
    return combined
[docs]
def fit_mcmc_jax(
    data: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    phi: np.ndarray,
    q: float,
    L: float,  # noqa: ARG001 — accepted for homodyne parity
    analysis_mode: str,  # noqa: ARG001 — accepted for homodyne parity
    method: str = "mcmc",  # noqa: ARG001 — accepted for homodyne parity
    cmc_config: dict[str, Any] | CMCConfig | None = None,
    initial_values: dict[str, float] | None = None,
    parameter_space: Any | None = None,
    dt: float | None = None,
    output_dir: Any | None = None,  # noqa: ARG001 — accepted for homodyne parity
    progress_bar: bool = True,  # noqa: ARG001 — accepted for homodyne parity
    run_id: str | None = None,  # noqa: ARG001 — accepted for homodyne parity
    nlsq_result: NLSQResult | dict | None = None,
    **kwargs: Any,
) -> CMCResult:
    """Homodyne-parity entry point for heterodyne CMC.

    Accepts the pooled-array call signature of homodyne's ``fit_mcmc_jax``,
    builds a :class:`HeterodyneModel` on the time grid recovered from the
    pooled coordinates, and routes the flat arrays into the shared pooled
    engine :func:`_fit_cmc_pooled`. One or several phi angles are both
    supported; single-phi is simply the ``n_phi=1`` case of the same path.

    Parameters
    ----------
    data, t1, t2, phi : np.ndarray
        Pooled C2 values and time/angle coordinates, all of identical shape
        ``(n_total,)``. Together they must cover a complete regular
        ``(phi, t1, t2)`` meshgrid exactly once; point order is free.
    q : float
        Wavevector magnitude.
    L, analysis_mode, method, output_dir, progress_bar, run_id :
        Accepted for homodyne parity and not used by this function.
    cmc_config : dict or CMCConfig, optional
        CMC configuration. Dicts (or ``None``) go through
        ``CMCConfig.from_dict``.
    initial_values : dict[str, float], optional
        Values written into ``space.values`` for every name already present
        there. NOTE: when ``parameter_space`` is supplied, that object is
        modified in place.
    parameter_space : ParameterSpace, optional
        Pre-built parameter space. ``None`` constructs ``ParameterSpace()``.
    dt : float, optional
        Time step. ``None`` uses the median spacing of the recovered time
        grid (``1.0`` when the grid has a single point).
    nlsq_result : NLSQResult or dict, optional
        NLSQ warm-start. Only :class:`NLSQResult` instances are used (the
        same result is replicated for every angle); anything else is ignored
        with a warning.
    **kwargs :
        Ignored; logged at debug level.

    Returns
    -------
    CMCResult
        Result of the CMC fit.

    Raises
    ------
    ValueError
        If the input shapes differ, the input is empty, the size does not
        match the recovered ``n_phi x n_t x n_t`` grid, or the grid is not
        fully covered.
    """
    if kwargs:
        logger.debug(
            "fit_mcmc_jax ignoring kwargs (homodyne parity): %s",
            sorted(kwargs.keys()),
        )
    if isinstance(cmc_config, CMCConfig):
        config = cmc_config
    else:
        config = CMCConfig.from_dict(cmc_config or {})
    _validate_config_or_raise(config)

    data_arr = np.asarray(data, dtype=np.float64)
    t1_arr = np.asarray(t1, dtype=np.float64)
    t2_arr = np.asarray(t2, dtype=np.float64)
    phi_arr = np.asarray(phi, dtype=np.float64)
    if not (data_arr.shape == t1_arr.shape == t2_arr.shape == phi_arr.shape):
        raise ValueError(
            "fit_mcmc_jax: pooled data/t1/t2/phi shape mismatch: "
            f"data={data_arr.shape} t1={t1_arr.shape} "
            f"t2={t2_arr.shape} phi={phi_arr.shape}"
        )
    if data_arr.size == 0:
        # Fail early with a clear message; an empty grid would otherwise
        # surface later as an opaque index error.
        raise ValueError("fit_mcmc_jax: pooled data is empty; nothing to fit.")

    # Recover the regular (t, t) grid the pooled arrays were flattened from.
    t_unique = np.unique(np.concatenate([t1_arr, t2_arr]))
    n_t = int(t_unique.size)
    # return_inverse yields each point's angle index directly, avoiding an
    # (n_total x n_phi) distance matrix for the lookup.
    unique_phi, phi_inverse = np.unique(phi_arr, return_inverse=True)
    n_phi = int(unique_phi.size)
    if data_arr.size != n_phi * n_t * n_t:
        raise ValueError(
            f"fit_mcmc_jax: pooled data size {data_arr.size} does not match "
            f"recovered grid {n_phi}x{n_t}x{n_t}={n_phi * n_t * n_t}; "
            "heterodyne CMC requires a regular meshgrid of (phi, t1, t2)."
        )
    phi_indices = phi_inverse.astype(np.int32)
    i1 = _grid_indices(t_unique, t1_arr, axis="t1")
    i2 = _grid_indices(t_unique, t2_arr, axis="t2")

    # Validate full (phi, t1, t2) coverage WITHOUT materialising a dense
    # (n_phi, n_t, n_t) float matrix: scatter into a 1-byte presence mask.
    # Together with the size check above, a fully covered mask proves the
    # meshgrid has neither gaps nor duplicates. ``order`` then sorts the
    # points into the canonical angle-major / i1-major layout.
    cell_ids = (phi_indices.astype(np.int64) * n_t + i1) * n_t + i2
    seen = np.zeros(n_phi * n_t * n_t, dtype=bool)
    seen[cell_ids] = True
    if not seen.all():
        raise ValueError(
            "fit_mcmc_jax: pooled (data, t1, t2, phi) does not cover the "
            f"full ({n_phi}, {n_t}, {n_t}) grid; cannot reconstruct stacked "
            "C2 matrices."
        )
    del seen
    order = np.argsort(cell_ids, kind="stable")

    inferred_dt = float(np.median(np.diff(t_unique))) if n_t > 1 else 1.0
    dt_value = float(dt) if dt is not None else inferred_dt
    t_start_value = float(t_unique[0])

    # Local imports avoid the import cycle that would arise from importing
    # heterodyne.core.heterodyne_model at module top (HeterodyneModel
    # transitively depends on optimization.cmc through other code paths).
    from heterodyne.config.parameter_manager import ParameterManager
    from heterodyne.config.parameter_space import ParameterSpace
    from heterodyne.core.heterodyne_model import HeterodyneModel
    from heterodyne.core.models import TwoComponentModel
    from heterodyne.core.physics_factors import create_physics_factors
    from heterodyne.core.scaling_utils import PerAngleScaling, ScalingConfig

    space = parameter_space if parameter_space is not None else ParameterSpace()
    if initial_values:
        # Only names the space already knows are overwritten; unknown keys
        # are skipped.
        for name, val in initial_values.items():
            if name in space.values:
                space.values[name] = float(val)
    param_manager = ParameterManager(space=space)
    factors = create_physics_factors(
        n_times=n_t,
        dt=dt_value,
        q=float(q),
        phi_angle=0.0,
        t_start=t_start_value,
    )
    scaling_values = getattr(space, "scaling_values", None) or {}
    scaling = PerAngleScaling.from_config(
        ScalingConfig(
            n_angles=1,
            mode="constant",
            initial_contrast=float(scaling_values.get("contrast", 1.0)),
            initial_offset=float(scaling_values.get("offset", 1.0)),
        )
    )
    model = HeterodyneModel(
        _model=TwoComponentModel(),
        param_manager=param_manager,
        _factors=factors,
        scaling=scaling,
        _t=factors.t,
    )

    nlsq_obj: NLSQResult | None = None
    if isinstance(nlsq_result, NLSQResult):
        nlsq_obj = nlsq_result
    elif nlsq_result is not None:
        logger.warning(
            "fit_mcmc_jax: nlsq_result is %s, not NLSQResult; warm-start disabled.",
            type(nlsq_result).__name__,
        )
    # The pooled engine expects one NLSQ result per angle; replicate the
    # single result across all angles.
    nlsq_results_list: list[NLSQResult] | None = (
        [nlsq_obj] * n_phi if nlsq_obj is not None else None
    )

    # Re-express each pooled point on the model's regular grid by index and
    # feed the flat arrays straight into the shared pooled engine — no dense
    # (n_phi, n_t, n_t) reconstruction, no meshgrid rebuild.
    model_t = np.asarray(model.t, dtype=np.float64)
    prepared = prepare_mcmc_data(
        data_arr[order],
        model_t[i1[order]],
        model_t[i2[order]],
        unique_phi[phi_indices[order]],
    )
    return _fit_cmc_pooled(model, prepared, model_t, config, nlsq_results_list)
[docs]
def run_cmc_analysis(
    model: HeterodyneModel,
    c2_data: np.ndarray | jnp.ndarray,
    config: CMCConfig | None = None,
    **kwargs: Any,
) -> CMCResult:
    """Thin alias for :func:`fit_cmc_jax` kept for homodyne parity.

    ``model`` and ``c2_data`` are passed positionally, ``config`` by keyword,
    and any extra keyword arguments are forwarded untouched. The
    ``CMCResult`` from :func:`fit_cmc_jax` is returned as is.
    """
    result = fit_cmc_jax(model, c2_data, config=config, **kwargs)
    return result