CMC (Bayesian)

Bayesian posterior sampling via NumPyro NUTS, with NLSQ-derived warm-start initialization, configurable priors, reparameterization, and ArviZ-compatible convergence diagnostics.

Core Fitting

Core CMC fitting functions for heterodyne Bayesian analysis.

Includes the original single-run fit_cmc_jax and the new sharded Consensus Monte Carlo entry point fit_cmc_sharded, plus all supporting helpers for shard creation, prior tempering, and posterior combination.

heterodyne.optimization.cmc.core.CMC_F0_DEGEN_THRESHOLD: float = 0.1

Sample fraction below which alpha_sample/D0_sample/D_offset_sample are unidentifiable, causing 100% shard failure (het_bb97531f failure mode).

heterodyne.optimization.cmc.core.CMC_ALPHA_SINGULARITY: float = -1.5

alpha_sample value at which J_sample(t) ∝ t^α has a non-integrable singularity at t→0, collapsing NUTS step-size for the sample group.

heterodyne.optimization.cmc.core.CMC_F0_SAFE_ZONE: float = 0.2

Target value for f0 when auto-clamping a degenerate warm-start. Set to 2× the degeneracy threshold (not just above it) so NUTS leapfrog steps cannot easily reflect back into the unidentifiable region. The borderline value CMC_F0_DEGEN_THRESHOLD + 0.01 = 0.11 used in het_a10cf27e produced 47/47 shard failure even after clamping; this wider margin is the het_a10cf27e fix.

heterodyne.optimization.cmc.core.CMC_ALPHA_SAFE_ZONE: float = -1.0

Target value for alpha_sample when auto-clamping. Sits 0.5 above the t^α singularity (vs the previous 0.1 borderline value). The wider margin keeps NUTS step-size adaptation stable when warm-starting from the degenerate region.
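The four constants above can be read together as a small clamping rule. The following is a minimal sketch of that rule, not the library's implementation: the helper name clamp_warm_start and the dict-based warm-start layout are assumptions made for illustration, and the sketch does not model how clamping interacts with allow_degenerate_warmstart.

```python
# Values copied from the constants documented above.
CMC_F0_DEGEN_THRESHOLD = 0.1
CMC_ALPHA_SINGULARITY = -1.5
CMC_F0_SAFE_ZONE = 0.2
CMC_ALPHA_SAFE_ZONE = -1.0


def clamp_warm_start(values: dict[str, float]) -> dict[str, float]:
    """Hypothetical helper: move a degenerate warm-start into the safe zones."""
    clamped = dict(values)
    # f0 below the degeneracy threshold jumps to the safe zone, not just past the threshold.
    if clamped.get("f0", CMC_F0_SAFE_ZONE) < CMC_F0_DEGEN_THRESHOLD:
        clamped["f0"] = CMC_F0_SAFE_ZONE
    # alpha_sample below the t^alpha singularity is moved 0.5 above it.
    if clamped.get("alpha_sample", CMC_ALPHA_SAFE_ZONE) < CMC_ALPHA_SINGULARITY:
        clamped["alpha_sample"] = CMC_ALPHA_SAFE_ZONE
    return clamped


start = clamp_warm_start({"f0": 0.04, "alpha_sample": -1.7, "D0_ref": 120.0})
```

Values that already sit outside the degenerate region pass through unchanged.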

heterodyne.optimization.cmc.core.fit_cmc_jax(model, c2_data, phi_angle=0.0, config=None, sigma=None, nlsq_result=None, t_override=None, priors_override=None, prior_width_multiplier=1.0)[source]

Fit the heterodyne model with a single-run (unsharded) CMC fit.

Uses NumPyro’s NUTS sampler for Bayesian posterior inference.

Parameters:
  • model (HeterodyneModel) – HeterodyneModel with configured parameters

  • c2_data (ndarray | Array) – Observed correlation data

  • phi_angle (float) – Detector phi angle (degrees)

  • config (CMCConfig | None) – CMC configuration (default if None)

  • sigma (ndarray | float | None) – Measurement uncertainty (estimated if None)

  • nlsq_result (NLSQResult | None) – Optional NLSQ result for warm-starting

  • t_override (ndarray | None) – Optional time array replacing model.t for model construction. Used by fit_cmc_sharded to pass shard time slices. If None, falls back to model.t.

  • priors_override (dict | None) – Optional dict of pre-built NumPyro distributions keyed by parameter name. When provided, these distributions replace the default space.priors for matching parameters. Used by fit_cmc_sharded to inject tempered shard priors into the non-reparam model path.

  • prior_width_multiplier (float) – Scalar multiplier applied to the scale of each reparam-path prior AFTER nlsq_prior_width_factor scaling. Default 1.0 (no change). Used by fit_cmc_sharded to widen reparam priors by sqrt(K).

Return type:

CMCResult

Returns:

CMCResult with posterior samples and diagnostics

heterodyne.optimization.cmc.core.fit_cmc_sharded(model, c2_data, phi_angle=0.0, config=None, sigma=None, nlsq_result=None, num_shards=4, sharding_strategy='random', shard_seed=None)[source]

Fit heterodyne model using sharded Consensus Monte Carlo.

Splits the observed c2 matrix into num_shards independent data subsets, runs NUTS on each shard sub-posterior (sequentially), then combines the shard posteriors via inverse-variance weighted consensus.

Prior tempering is applied automatically: each shard’s prior distribution is widened by sqrt(num_shards) (i.e., prior^(1/K)) while sigma is passed unscaled. This is the correct Consensus Monte Carlo approach (Scott et al., 2016).

Parameters:
  • model (HeterodyneModel) – HeterodyneModel with configured parameters.

  • c2_data (ndarray | Array) – Observed two-time correlation matrix (N x N).

  • phi_angle (float) – Detector phi angle (degrees).

  • config (CMCConfig | None) – CMC configuration (defaults to CMCConfig()).

  • sigma (ndarray | float | None) – Measurement uncertainty (estimated if None).

  • nlsq_result (NLSQResult | None) – Optional NLSQ result for warm-starting each shard.

  • num_shards (int) – Number of data shards (K). Must be >= 2.

  • sharding_strategy (str) – One of "random" (default) or "contiguous". Random sharding breaks temporal autocorrelation between shards. Contiguous sharding uses diagonal time-blocks, which preserves the two-time structure within each shard.

  • shard_seed (int | None) – Integer seed for deterministic shard assignment. If None, a random seed is drawn from the OS.

Return type:

CMCResult

Returns:

CMCResult with combined posterior and per-shard diagnostics stored in result.metadata["shard_diagnostics"].

Raises:

ValueError – If inputs fail validation or num_shards < 2.
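To see why widening each shard prior by sqrt(K) and combining with inverse-variance weights is consistent, it may help to check the idea on a toy problem with a closed-form answer. The sketch below uses a conjugate Gaussian model (unknown mean, known noise), which is an assumption chosen purely for illustration and unrelated to the heterodyne physics; none of the names come from the library.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.5                 # known measurement noise
mu0, tau0 = 0.0, 2.0        # full-data Gaussian prior N(mu0, tau0^2)
K = 4                       # number of shards
data = rng.normal(1.3, sigma, size=400)
shards = np.array_split(rng.permutation(data), K)   # random sharding


def gaussian_posterior(x, prior_mean, prior_std):
    """Conjugate posterior (mean, variance) for a Gaussian mean with known sigma."""
    precision = 1.0 / prior_std**2 + x.size / sigma**2
    mean = (prior_mean / prior_std**2 + x.sum() / sigma**2) / precision
    return mean, 1.0 / precision


# Tempered prior: N(mu0, tau0^2)^(1/K) is proportional to N(mu0, (tau0 * sqrt(K))^2).
shard_post = [gaussian_posterior(s, mu0, tau0 * np.sqrt(K)) for s in shards]
means = np.array([m for m, _ in shard_post])
weights = np.array([1.0 / v for _, v in shard_post])   # inverse variances

# Inverse-variance weighted consensus of the shard posteriors.
consensus_mean = np.sum(weights * means) / np.sum(weights)
consensus_var = 1.0 / np.sum(weights)

# Reference: posterior from all data with the untempered prior.
full_mean, full_var = gaussian_posterior(data, mu0, tau0)
```

In this Gaussian case the consensus mean and variance match the full-data posterior up to floating-point error. For non-Gaussian shard posteriors the combination is only approximate.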

heterodyne.optimization.cmc.core.fit_cmc_multi_phi(model, c2_data, phi_angles, config=None, nlsq_results=None, sigma=None)[source]

Joint multi-phi CMC entry point (homodyne parity).

Fits the pooled multi-phi data with 14 shared physics parameters and per-angle contrast / offset. Mirrors homodyne’s _fit_mcmc_jax_impl: small datasets run a single NUTS pass; large datasets (n_total above the single-shard limit, or when num_shards / max_points_per_shard is set explicitly) are sharded and combined by Consensus Monte Carlo, because NUTS costs O(n) per leapfrog step and a single pass over millions of pooled points is intractable.

Algorithm:

  1. Pool c2_data (shape (n_phi, N, N) or (N, N) for n_phi=1) into flat arrays (data, t1, t2, phi) of length n_phi * N * N.

  2. Run prepare_mcmc_data() to filter the diagonal and build a PooledCMCData container with phi_unique and phi_indices (homodyne layout).

  3. Compute per-point grid indices i1_indices / i2_indices via searchsorted against model.t.

  4. Build the joint NumPyro model xpcs_model_heterodyne_scaled() (per-angle sampled scaling + shared physics + single pooled likelihood with t=0 boundary mask).

  5. Run NUTS with the configured chains / warmup / samples.

  6. Return a single CMCResult with shared-physics posterior + per-angle scaling posteriors (mean_contrast / mean_offset arrays of length n_phi).

The returned CMCResult reflects one joint inference — every angle contributes to the same physics-parameter posterior, exactly as in homodyne.

Parameters:
  • model (HeterodyneModel) – Configured HeterodyneModel whose time grid model.t defines the (N,) axis the pooled c2 was flattened from.

  • c2_data (ndarray | Array) – Experimental c2 of shape (n_phi, N, N) (multi-angle) or (N, N) (single-angle, treated as n_phi=1).

  • phi_angles (ndarray | list[float]) – Detector phi angles in degrees, length n_phi.

  • config (CMCConfig | None) – CMCConfig. None uses defaults.

  • nlsq_results (list[NLSQResult] | None) – Optional per-angle NLSQ warm-start. Currently only used to log warm-start status; future phases will translate to init_to_value.

  • sigma (ndarray | float | None) – Optional measurement uncertainty estimate; None triggers a MAD-based estimate via prepare_mcmc_data().

Returns:

Joint multi-phi result. parameter_names lists the 14 physics parameters followed by contrast_0..contrast_{n_phi-1} and offset_0..offset_{n_phi-1}.

Return type:

CMCResult
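Steps 1 to 3 of the fit_cmc_multi_phi algorithm (pooling, diagonal filtering, grid indexing) can be sketched with plain NumPy. This is an illustration of the data layout only: it does not call prepare_mcmc_data() or build a PooledCMCData container, and the tiny arrays are made up.

```python
import numpy as np

t = np.array([0.0, 0.1, 0.2, 0.3])            # stands in for model.t, shape (N,)
phi_angles = np.array([0.0, 45.0, 90.0])       # degrees, length n_phi
n_phi, N = phi_angles.size, t.size
c2 = np.random.default_rng(0).normal(1.0, 0.01, size=(n_phi, N, N))

# Step 1: flatten (n_phi, N, N) into pooled arrays of length n_phi * N * N.
t1_grid, t2_grid = np.meshgrid(t, t, indexing="ij")
data = c2.reshape(-1)
t1 = np.tile(t1_grid.reshape(-1), n_phi)
t2 = np.tile(t2_grid.reshape(-1), n_phi)
phi = np.repeat(phi_angles, N * N)

# Step 2 (sketch only): drop the t1 == t2 diagonal and index the angles.
keep = t1 != t2
data, t1, t2, phi = data[keep], t1[keep], t2[keep], phi[keep]
phi_unique, phi_indices = np.unique(phi, return_inverse=True)

# Step 3: map each pooled point back to its row / column on the time grid.
i1_indices = np.searchsorted(t, t1)
i2_indices = np.searchsorted(t, t2)
```

With these indices, c2[phi_indices, i1_indices, i2_indices] reproduces the pooled data array, which is a convenient sanity check on the layout.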

heterodyne.optimization.cmc.core.fit_mcmc_jax(data, t1, t2, phi, q, L, analysis_mode, method='mcmc', cmc_config=None, initial_values=None, parameter_space=None, dt=None, output_dir=None, progress_bar=True, run_id=None, nlsq_result=None, **kwargs)[source]

Homodyne-parity entry point for heterodyne CMC.

Mirrors homodyne.optimization.cmc.fit_mcmc_jax’s pooled-array call signature and routes to heterodyne’s native fit_cmc_jax / fit_cmc_sharded. This adapter exists so cross-package CLI / driver code that follows homodyne’s pooled-data convention can call heterodyne’s CMC pipeline without reshaping by hand.

Heterodyne’s native API is fit_cmc_jax(model, c2_data, phi_angle, config, ...) — new heterodyne code should prefer that directly.

Parameters:
  • data, t1, t2, phi (ndarray) – Pooled C2 values (data) and their time / angle coordinates (t1, t2, phi), all shape (n_total,). (t1, t2) must be a flattened regular meshgrid for the reverse reshape to succeed.

  • q (float) – Wavevector magnitude (Å⁻¹).

  • L (float) – Stator-rotor gap. Accepted for homodyne parity; heterodyne’s physics model uses absolute time scaling, not L-normalised dimensionless time.

  • analysis_mode (str) – Accepted for homodyne parity; heterodyne always uses its two-component model regardless of this value.

  • method (str) – Accepted for homodyne parity; consumed by heterodyne’s native pipeline where applicable.

  • output_dir (Any | None) – Accepted for homodyne parity; consumed by heterodyne’s native pipeline where applicable.

  • progress_bar (bool) – Accepted for homodyne parity; consumed by heterodyne’s native pipeline where applicable.

  • run_id (str | None) – Accepted for homodyne parity; consumed by heterodyne’s native pipeline where applicable.

  • cmc_config (dict[str, Any] | CMCConfig | None) – CMC configuration. Dicts are converted via CMCConfig.from_dict.

  • initial_values (dict[str, float] | None) – Initial parameter values applied to the constructed model.

  • parameter_space (Any | None) – Pre-built ParameterSpace. When None, a default one is built from DEFAULT_REGISTRY.

  • dt (float | None) – Time step. When None, inferred from np.diff over the unique values of t1 and t2.

  • nlsq_result (NLSQResult | dict | None) – Optional NLSQ warm-start. Dicts are ignored with a warning since heterodyne’s native warm-start requires an NLSQResult instance.

Returns:

Result of the CMC fit.

Return type:

CMCResult

Raises:
  • ValueError – If input shapes mismatch or the (t1, t2) pooled grid is not recoverable.

  • NotImplementedError – If pooled data contains multiple distinct phi angles; call this adapter once per angle, or use heterodyne.cli.optimization_runner.run_cmc for orchestrated multi-angle CMC.
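The "reverse reshape" requirement on (t1, t2) can be illustrated as follows. This is a sketch under assumptions: unpool_single_angle is a hypothetical helper for a single phi angle, and the exact checks and error messages in the adapter may differ.

```python
import numpy as np


def unpool_single_angle(data, t1, t2):
    """Hypothetical helper: rebuild the (N, N) c2 matrix from pooled arrays."""
    if data.shape != t1.shape or data.shape != t2.shape:
        raise ValueError("data, t1 and t2 must share the same shape")
    t_grid = np.unique(np.concatenate([t1, t2]))
    n = t_grid.size
    if data.size != n * n:
        raise ValueError("(t1, t2) is not a complete N x N grid")
    i1 = np.searchsorted(t_grid, t1)
    i2 = np.searchsorted(t_grid, t2)
    # Every (row, column) pair must appear exactly once.
    if np.unique(i1 * n + i2).size != n * n:
        raise ValueError("(t1, t2) contains duplicate or missing grid points")
    c2 = np.empty((n, n), dtype=data.dtype)
    c2[i1, i2] = data
    return t_grid, c2


t = np.linspace(0.0, 0.4, 5)
g1, g2 = np.meshgrid(t, t, indexing="ij")
original = np.arange(25, dtype=float).reshape(5, 5)
t_grid, recovered = unpool_single_angle(original.ravel(), g1.ravel(), g2.ravel())
dt = float(np.diff(t_grid)[0])   # time step inferred from the recovered grid
```

Pooled arrays that do not cover a full N x N grid fail the size check, which corresponds to the ValueError case listed above.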

heterodyne.optimization.cmc.core.run_cmc_analysis(model, c2_data, config=None, **kwargs)[source]

Convenience wrapper around fit_cmc_jax() (homodyne parity).

Accepts the same arguments as fit_cmc_jax() and delegates directly.

Return type:

CMCResult

Configuration

Configuration for CMC (Consensus Monte Carlo) analysis.

This module defines CMCConfig, a comprehensive dataclass covering all aspects of the heterodyne CMC pipeline: sharding strategy, backend selection, NUTS sampling parameters, convergence validation thresholds, reparameterization, prior tempering, shard combination, and run identification.

The heterodyne model has 14 free parameters (vs. 7 in homodyne). All auto-scaling formulas account for this increased dimensionality.

heterodyne.optimization.cmc.config.DENSE_MASS_WARMUP_FLOOR: int = 1500

Rule 12: NUTS warmup floor when dense_mass=True. Dense mass-matrix adaptation needs ≥ 100 steps per dimension; for the 14-parameter heterodyne model, 1500 warmup steps is the minimum that yields healthy R-hat in regression.

heterodyne.optimization.cmc.config.effective_warmup_floor(requested, *, dense_mass, fast_warmup=False)[source]

Return the safe warmup count, applying the Rule 12 floor when needed.

When dense_mass=True the NUTS dense mass-matrix adaptation requires at least DENSE_MASS_WARMUP_FLOOR steps; otherwise the requested value is returned unchanged. fast_warmup=True opts out of the floor and emits a one-shot warning — intended for CI fast-mode only.

Parameters:
  • requested (int) – Caller-supplied (possibly scaled) warmup count.

  • dense_mass (bool) – Whether the sampler is configured with dense_mass=True.

  • fast_warmup (bool) – Skip the Rule-12 floor (logs a one-shot warning).

Return type:

int

Returns:

Warmup count to use.
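The documented behaviour of effective_warmup_floor can be summarised in a few lines. The sketch below is an illustrative re-implementation, not the library source: the name warmup_floor_sketch is hypothetical, it assumes the floor is applied as max(requested, floor), it only warns when the floor would otherwise apply, and it does not enforce the one-shot nature of the warning.

```python
import warnings

DENSE_MASS_WARMUP_FLOOR = 1500   # value documented above


def warmup_floor_sketch(requested: int, *, dense_mass: bool, fast_warmup: bool = False) -> int:
    """Illustrative version of the Rule 12 warmup floor."""
    if not dense_mass:
        return requested                     # diagonal mass matrix: no floor
    if fast_warmup:
        warnings.warn("dense-mass warmup floor skipped (fast_warmup=True)", stacklevel=2)
        return requested                     # CI fast-mode opt-out
    return max(requested, DENSE_MASS_WARMUP_FLOOR)
```

For example, a request of 300 warmup steps with dense_mass=True is raised to 1500, while a request of 2000 is left alone.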

class heterodyne.optimization.cmc.config.CMCConfig[source]

Bases: object

Comprehensive configuration for Consensus Monte Carlo (CMC) analysis.

CMC splits a large dataset into K shards, runs NUTS independently on each shard, then combines the resulting posteriors using a consensus algorithm. This dataclass controls every knob across the full pipeline.

Parameters are grouped into logical sections, matching the structure of the to_dict / from_dict serialization format:

  • enable — master on/off switch and dataset-size gate.

  • per_angle — how to handle the phi (angle) dimension.

  • sharding — shard count, strategy, and size bounds.

  • backend_config — worker backend and checkpoint settings.

  • per_shard_mcmc — NUTS hyper-parameters and adaptive scaling.

  • validation — convergence thresholds and abort conditions.

  • nlsq — NLSQ warm-start and prior-width configuration.

  • prior_tempering — scale priors by 1/K for shard consistency.

  • combination — posterior combination algorithm and success criteria.

  • timeout — per-shard and heartbeat time limits.

  • reparameterization — parameter transforms and bimodality guards.

  • run_id — optional identifier for checkpoint namespacing.

enable

Master switch. "auto" enables CMC when n_points >= min_points_for_cmc. True / "always" forces CMC regardless of dataset size. False / "never" disables CMC entirely.

min_points_for_cmc

Minimum number of data points required before CMC is activated under enable="auto". Below this threshold the pipeline falls back to full NUTS.

per_angle_mode

Strategy for handling the angle (phi) dimension. "auto" selects automatically based on n_phi and constant_scaling_threshold.

constant_scaling_threshold

Minimum number of phi angles required before switching from "constant" to "individual" mode when per_angle_mode="auto".

sharding_strategy

How to partition data across shards: "stratified" preserves angle distributions, "random" shuffles globally, "contiguous" uses contiguous memory blocks.

num_shards

Number of shards K. "auto" derives K from dataset size, phi count, and min_points_per_shard / min_points_per_param.

max_points_per_shard

Upper bound on shard size. "auto" disables the cap.

min_points_per_shard

Lower bound on shard size; prevents degenerate under-determined shards.

min_points_per_param

Minimum number of data points per free parameter in each shard (default 1500; the heterodyne model has 14 free parameters).

backend_name

Worker backend. "auto" selects based on available CPU devices and core count.

enable_checkpoints

Persist intermediate shard results to checkpoint_dir.

checkpoint_dir

Directory for shard checkpoint files.

chain_method

Whether to run chains in "parallel" or "sequential" order within each shard worker.

num_warmup

Number of NUTS warm-up (burn-in) steps per chain. Default 1500 provides ~100 steps per parameter for the 14-parameter heterodyne model when dense_mass=True; lower values leave mass-matrix adaptation incomplete and produce high R-hat / divergence storms.

num_samples

Number of posterior draws per chain after warm-up.

num_chains

Number of independent MCMC chains per shard.

target_accept_prob

Target acceptance probability for the dual-averaging NUTS step-size adaptation (must be in [0.5, 0.99]). Default 0.90 keeps step sizes small enough to traverse the (D0, alpha) funnel without divergence cascades; reduce only if you understand the geometry.

max_tree_depth

Maximum binary tree depth for NUTS leapfrog integration.

seed

Base random seed for deterministic reproducibility.

dense_mass

Use a dense (full-covariance) mass matrix. Expensive but more accurate for highly correlated posteriors.

init_strategy

NUTS initialisation strategy.

adaptive_sampling

Scale num_warmup / num_samples down proportionally when shard size is below _REFERENCE_SHARD_SIZE.

min_warmup

Floor on adaptive warm-up count.

min_samples

Floor on adaptive sample count.

max_r_hat

Maximum acceptable Gelman-Rubin statistic; chains with R-hat > max_r_hat are flagged as not converged.

min_ess

Minimum effective sample size per parameter.

min_bfmi

Minimum Bayesian Fraction of Missing Information (energy diagnostic).

max_divergence_rate

Maximum fraction of divergent transitions before a shard is rejected.

require_nlsq_warmstart

Abort if an NLSQ warm-start was requested but unavailable.

allow_degenerate_warmstart

When False (default) an early RuntimeError is raised before dispatching any shards if the warm-start is in a regime that is guaranteed to cause 100% shard failure (f0 < 0.10 or alpha_sample < -1.5 — the het_bb97531f failure mode). Set to True to bypass the abort and let NUTS attempt the run anyway; useful when you want to see exactly how bad the posteriors are or when you have increased num_warmup substantially (≥ 2000).

max_parameter_cv

Maximum allowed coefficient of variation across chains for any parameter; guards against pathological multi-modal posteriors.

heterogeneity_abort

Abort the entire CMC run if shards produce incompatible posteriors (detected via KL divergence or parameter-CV checks).

use_nlsq_warmstart

Initialise each shard’s NUTS chains from the NLSQ MAP estimate.

use_nlsq_informed_priors

Centre Gaussian priors on NLSQ estimates scaled by nlsq_prior_width_factor.

nlsq_prior_width_factor

Scale factor applied to NLSQ parameter uncertainties when constructing informed priors.

prior_tempering

Divide each shard’s log-prior by K (the number of shards), i.e. use prior^(1/K), so that the product of the K shard priors recovers the full-data prior.

combination_method

Algorithm used to combine shard posteriors.

min_success_rate

Minimum fraction of shards that must converge; run fails below this.

min_success_rate_warning

Fraction below which a warning is emitted even if the run succeeds.

per_shard_timeout

Wall-clock seconds allowed per shard before it is cancelled.

heartbeat_timeout

Seconds of silence from a worker before it is declared dead.

use_reparam

Apply parameter reparameterisations (e.g. log-transforms) in NumPyro.

reparameterization_d_total

Reparameterise d_total = d_fast + d_slow as an unconstrained sum.

reparameterization_log_gamma

Reparameterise gamma on a log scale to enforce positivity.

bimodal_min_weight

Minimum mixture weight for the minor mode in bimodal posteriors; below this the minor mode is discarded.

bimodal_min_separation

Minimum normalised distance between modes to declare bimodality.

run_id

Optional string identifier for this CMC run, used in checkpoint paths and log messages.

enable: bool | str = 'auto'
min_points_for_cmc: int = 100000
per_angle_mode: str = 'auto'
constant_scaling_threshold: int = 3
sharding_strategy: str = 'random'
num_shards: int | str = 'auto'
max_points_per_shard: int | str = 'auto'
min_points_per_shard: int = 10000
min_points_per_param: int = 1500
backend_name: str = 'auto'
enable_checkpoints: bool = True
checkpoint_dir: str = './checkpoints/cmc'
chain_method: str = 'parallel'
num_warmup: int = 1500
num_samples: int = 1500
num_chains: int = 4
target_accept_prob: float = 0.9
max_tree_depth: int = 10
seed: int = 42
dense_mass: bool = True
init_strategy: str = 'init_to_median'
adaptive_sampling: bool = True
min_warmup: int = 100
min_samples: int = 200
fast_warmup: bool = False

Rule 12 escape hatch: skip the dense-mass warmup floor (1500 steps). Intended for CI fast-mode and pytest fixtures only; NOT for production. When True, all warmup calculations bypass effective_warmup_floor().

max_r_hat: float = 1.1
min_ess: int = 400
min_bfmi: float = 0.3
max_divergence_rate: float = 0.1
require_nlsq_warmstart: bool = False
allow_degenerate_warmstart: bool = False
max_parameter_cv: float = 1.0
heterogeneity_abort: bool = True
use_nlsq_warmstart: bool = True
use_nlsq_informed_priors: bool = True
nlsq_prior_width_factor: float = 2.0
use_log_space_priors: bool = True

Codex S1: integrate build_log_space_priors into build_default_priors. When True (default), parameters flagged log_space=True in the parameter registry (currently D0_ref, D0_sample, v0) are sampled with LogNormal priors instead of TruncatedNormal, which gives better mass-matrix conditioning for prefactors that span several orders of magnitude. When False, the registry’s log_space flag is ignored and all parameters use TruncatedNormal. The reparameterized path (use_reparam=True) is unaffected: it samples log_X_at_tref directly and does not reach this code path.

prior_tempering: bool = True
combination_method: str = 'robust_consensus_mc'
min_success_rate: float = 0.9
min_success_rate_warning: float = 0.8
per_shard_timeout: int = 7200
heartbeat_timeout: int = 600
enable_jax_profiling: bool = False
jax_profile_dir: str = './profiles/jax'
use_reparam: bool = True
reparameterization_d_total: bool = True
reparameterization_log_gamma: bool = True
bimodal_min_weight: float = 0.2
bimodal_min_separation: float = 0.5
run_id: str | None = None
__post_init__()[source]

Normalise string-valued enable flag and log construction.

Return type:

None

validate()[source]

Run comprehensive field validation and return a list of error strings.

Returns:

Empty list when the configuration is valid; one entry per violation otherwise. Does not raise — callers decide how to handle errors.

Return type:

list[str]

is_valid()[source]

Return True if the configuration passes all validation checks.

Equivalent to len(self.validate()) == 0.

Return type:

bool

should_enable_cmc(n_points, analysis_mode=None)[source]

Decide whether to run CMC given the dataset size.

Parameters:
  • n_points (int) – Total number of data points in the dataset.

  • analysis_mode (str | None) – Optional homodyne-parity kwarg. Accepted and ignored — heterodyne does not gate CMC on the analyzer mode. Present so that callers ported from homodyne continue to work.

Returns:

True if CMC should run for this dataset.

Return type:

bool
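The enable / min_points_for_cmc gate described for CMCConfig can be sketched as a standalone function. The name should_enable_cmc_sketch is hypothetical, and raising ValueError for an unrecognised value is an assumption of this sketch rather than documented behaviour.

```python
def should_enable_cmc_sketch(n_points, *, enable="auto", min_points_for_cmc=100_000):
    """Illustrative version of the documented enable switch."""
    if enable is True or enable == "always":
        return True                               # force CMC regardless of size
    if enable is False or enable == "never":
        return False                              # CMC disabled entirely
    if enable == "auto":
        return n_points >= min_points_for_cmc     # dataset-size gate
    raise ValueError(f"unrecognised enable value: {enable!r}")
```

With the default threshold, a 250 000-point dataset enables CMC under "auto" and a 50 000-point dataset does not.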

get_num_shards(n_points, n_phi, n_params=_N_PARAMS_HETERODYNE)[source]

Compute the number of shards K for a given dataset.

When num_shards is an explicit integer it is returned directly (clamped to >= 1). When "auto", K is derived as:

  1. Start from max(n_phi, 2) — at least as many shards as phi angles.

  2. Apply the min_points_per_shard lower bound: K <= n_points // min_points_per_shard.

  3. Apply the min_points_per_param constraint: K <= n_points // (n_params * min_points_per_param).

  4. Apply the max_points_per_shard upper bound when set: K >= ceil(n_points / max_points_per_shard).

  5. Clamp to [1, n_points].

Parameters:
  • n_points (int) – Total number of data points.

  • n_phi (int) – Number of distinct phi (azimuthal angle) bins.

  • n_params (int) – Number of free model parameters (default: 14 for heterodyne).

Returns:

Number of shards K >= 1.

Return type:

int
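A literal reading of the five "auto" steps for get_num_shards is sketched below. The function name num_shards_sketch is hypothetical, the defaults are copied from the CMCConfig field defaults, and the real method may resolve conflicting bounds differently from this step-by-step ordering.

```python
import math


def num_shards_sketch(n_points, n_phi, *, num_shards="auto", n_params=14,
                      min_points_per_shard=10_000, min_points_per_param=1_500,
                      max_points_per_shard="auto"):
    """Illustrative version of the documented shard-count derivation."""
    if num_shards != "auto":
        return max(1, int(num_shards))                         # explicit K, clamped to >= 1
    k = max(n_phi, 2)                                          # step 1
    k = min(k, n_points // min_points_per_shard)               # step 2
    k = min(k, n_points // (n_params * min_points_per_param))  # step 3
    if max_points_per_shard != "auto":                         # step 4
        k = max(k, math.ceil(n_points / max_points_per_shard))
    return max(1, min(k, n_points))                            # step 5
```

Under this reading, one million points with three phi angles give K = 3, and adding max_points_per_shard=50_000 raises that to K = 20.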

get_adaptive_sample_counts(shard_size, n_params=_N_PARAMS_HETERODYNE)[source]

Scale warmup and sample counts for a given shard size.

When adaptive_sampling=False the configured num_warmup and num_samples are returned unchanged.

The scaling law is:

scale = clamp(shard_size / reference_size, 0, 1)
warmup = max(min_warmup, round(num_warmup * scale))
samples = max(min_samples, round(num_samples * scale))

where reference_size = _REFERENCE_SHARD_SIZE (10 000 points) is the shard size at which the full configured counts are used. Larger shards are not scaled up beyond the configured maximum; the formula saturates at scale = 1.

A secondary check ensures a minimum of n_params samples are drawn (ESS cannot exceed num_samples * num_chains).

Parameters:
  • shard_size (int) – Number of data points in this shard.

  • n_params (int) – Number of model parameters (default: 14).

Returns:

(warmup, samples) after adaptive scaling.

Return type:

tuple[int, int]
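The scaling law for get_adaptive_sample_counts translates directly into code. The sketch below uses the documented defaults; the function name is hypothetical, and it leaves out any interaction with the dense-mass warmup floor, which the real pipeline may apply afterwards.

```python
REFERENCE_SHARD_SIZE = 10_000   # documented value of _REFERENCE_SHARD_SIZE


def adaptive_counts_sketch(shard_size, *, num_warmup=1500, num_samples=1500,
                           min_warmup=100, min_samples=200, n_params=14,
                           adaptive_sampling=True):
    """Illustrative version of the documented adaptive scaling."""
    if not adaptive_sampling:
        return num_warmup, num_samples
    scale = min(max(shard_size / REFERENCE_SHARD_SIZE, 0.0), 1.0)   # clamp to [0, 1]
    warmup = max(min_warmup, round(num_warmup * scale))
    samples = max(min_samples, round(num_samples * scale))
    samples = max(samples, n_params)   # secondary check: at least n_params draws
    return warmup, samples
```

A 5 000-point shard halves both counts, a 500-point shard falls back to the min_warmup / min_samples floors, and shards above the reference size keep the configured counts.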

get_effective_per_angle_mode(n_phi, nlsq_per_angle_mode=None, has_nlsq_warmstart=False)[source]

Resolve the effective per-angle mode for a concrete dataset.

Mirrors homodyne/optimization/cmc/config.py::get_effective_per_angle_mode so CMC/NLSQ parameterization stays in lock-step across both packages.

Resolution logic (priority order):

  1. If nlsq_per_angle_mode is provided, mirror it for CMC↔NLSQ parameterization parity, regardless of self.per_angle_mode. If both sides are "auto" AND has_nlsq_warmstart is True, promote to "constant_averaged" so scaling is fixed (fewer sampled params, less heterogeneity across shards).

  2. Else if self.per_angle_mode != "auto" → return it directly.

  3. Else (auto, no NLSQ): n_phi >= constant_scaling_threshold → "auto" (sampled averaged); otherwise → "individual".

Parameters:
  • n_phi (int) – Number of distinct phi (azimuthal angle) bins in the dataset.

  • nlsq_per_angle_mode (str | None) – The per-angle mode resolved by the preceding NLSQ fit, if any. When provided this overrides the configured mode for parity.

  • has_nlsq_warmstart (bool) – Whether a valid NLSQ warm-start is available for this run.

Returns:

Effective mode: "auto", "constant", "constant_averaged", or "individual".

Return type:

str
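The three-level priority order for get_effective_per_angle_mode can be written out as a plain function. This is a sketch of the documented resolution logic only; the function name is hypothetical and the configuration values are passed as keyword arguments instead of being read from a CMCConfig instance.

```python
def effective_per_angle_mode_sketch(n_phi, *, per_angle_mode="auto",
                                    constant_scaling_threshold=3,
                                    nlsq_per_angle_mode=None,
                                    has_nlsq_warmstart=False):
    """Illustrative version of the documented resolution order."""
    if nlsq_per_angle_mode is not None:                        # rule 1: mirror NLSQ
        if (nlsq_per_angle_mode == "auto" and per_angle_mode == "auto"
                and has_nlsq_warmstart):
            return "constant_averaged"                         # fixed scaling
        return nlsq_per_angle_mode
    if per_angle_mode != "auto":                               # rule 2: explicit mode
        return per_angle_mode
    if n_phi >= constant_scaling_threshold:                    # rule 3: auto, no NLSQ
        return "auto"
    return "individual"
```

Note that an NLSQ mode, when supplied, wins even over an explicitly configured per_angle_mode.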

to_dict()[source]

Serialise the configuration to a nested dictionary.

The returned structure uses the same section names expected by from_dict, making round-trips lossless.

Returns:

Nested dictionary representation of the config.

Return type:

dict[str, Any]

classmethod from_dict(config_dict)[source]

Construct a CMCConfig from a (possibly nested) dictionary.

Recognised top-level keys and sections:

  • enable, min_points_for_cmc, run_id — top-level scalars.

  • prior_tempering — top-level scalar.

  • per_angle — maps to per_angle_mode, constant_scaling_threshold.

  • sharding — maps to the five sharding fields.

  • backend_config — maps to the four backend fields.

  • per_shard_mcmc — maps to the eleven sampling fields.

  • validation — maps to the seven validation-threshold fields.

  • nlsq — maps to the three NLSQ-prior fields.

  • combination — maps to the three combination fields.

  • timeout — maps to the two timeout fields.

  • reparameterization — maps to the five reparam fields.

Flat (non-nested) dictionaries are also accepted for backward compatibility: any key that matches a field name directly is used as-is.

Unrecognised top-level keys emit a warnings.warn so that configuration typos surface immediately.

Parameters:

config_dict (dict[str, Any]) – Parsed YAML / JSON dictionary.

Returns:

Fully constructed configuration instance.

Return type:

CMCConfig

__init__(enable='auto', min_points_for_cmc=100000, per_angle_mode='auto', constant_scaling_threshold=3, sharding_strategy='random', num_shards='auto', max_points_per_shard='auto', min_points_per_shard=10000, min_points_per_param=1500, backend_name='auto', enable_checkpoints=True, checkpoint_dir='./checkpoints/cmc', chain_method='parallel', num_warmup=1500, num_samples=1500, num_chains=4, target_accept_prob=0.9, max_tree_depth=10, seed=42, dense_mass=True, init_strategy='init_to_median', adaptive_sampling=True, min_warmup=100, min_samples=200, fast_warmup=False, max_r_hat=1.1, min_ess=400, min_bfmi=0.3, max_divergence_rate=0.1, require_nlsq_warmstart=False, allow_degenerate_warmstart=False, max_parameter_cv=1.0, heterogeneity_abort=True, use_nlsq_warmstart=True, use_nlsq_informed_priors=True, nlsq_prior_width_factor=2.0, use_log_space_priors=True, prior_tempering=True, combination_method='robust_consensus_mc', min_success_rate=0.9, min_success_rate_warning=0.8, per_shard_timeout=7200, heartbeat_timeout=600, enable_jax_profiling=False, jax_profile_dir='./profiles/jax', use_reparam=True, reparameterization_d_total=True, reparameterization_log_gamma=True, bimodal_min_weight=0.2, bimodal_min_separation=0.5, run_id=None, _validation_errors=<factory>)

Note

Key attribute names (renamed from legacy): target_accept_prob, max_r_hat, nlsq_prior_width_factor. The from_dict() class method handles legacy key translation.

Results

Result container for CMC analysis.

class heterodyne.optimization.cmc.results.ParameterStats[source]

Bases: dict

Hybrid mapping/sequence for posterior summaries.

Supports dict-style access by name (ps["D0_ref"]) and integer-index access (ps[0]). Inherits dict so existing .get() / in checks continue to work unchanged.

__init__(ordered_names, values)[source]
property as_array: ndarray

Ordered values as a numpy array.

tolist()[source]
Return type:

list[float]
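The hybrid mapping/sequence behaviour of ParameterStats can be imitated with a small dict subclass. The class below, StatsSketch, is a hypothetical stand-in written for illustration; only the constructor arguments, as_array, and tolist() mirror names from the documentation, and the internals are assumptions.

```python
import numpy as np


class StatsSketch(dict):
    """Hypothetical stand-in: a dict that also accepts integer positions."""

    def __init__(self, ordered_names, values):
        super().__init__(zip(ordered_names, values))
        self._names = list(ordered_names)    # remembers the parameter order

    def __getitem__(self, key):
        if isinstance(key, int):
            key = self._names[key]           # translate position to name
        return super().__getitem__(key)

    @property
    def as_array(self):
        return np.array([self[name] for name in self._names])

    def tolist(self):
        return [float(v) for v in self.as_array]


ps = StatsSketch(["D0_ref", "alpha_ref"], [120.0, -0.5])
```

Because the class inherits from dict, ps.get("D0_ref") and "alpha_ref" in ps behave as usual, while ps[1] returns the second parameter's value.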

class heterodyne.optimization.cmc.results.CMCResult[source]

Bases: object

Result of CMC (Consensus Monte Carlo) analysis.

Contains posterior samples, summaries, and convergence diagnostics.

parameter_names: list[str]
posterior_mean: ndarray
posterior_std: ndarray
credible_intervals: dict[str, dict[str, float]]
convergence_passed: bool
r_hat: ndarray | None = None
ess_bulk: ndarray | None = None
ess_tail: ndarray | None = None
bfmi: list[float] | None = None
samples: dict[str, ndarray] | None = None
map_estimate: ndarray | None = None
num_warmup: int = 0
num_samples: int = 0
num_chains: int = 0
num_shards: int = 1
divergences: int = 0
wall_time_seconds: float | None = None
metadata: dict[str, Any]
convergence_status: str = 'not_converged'
warmup_time: float | None = None
per_angle_mode: str = 'auto'
chi_squared: float | None = None
quality_flag: str | None = None
mean_contrast: ndarray | None = None
std_contrast: ndarray | None = None
mean_offset: ndarray | None = None
std_offset: ndarray | None = None
inference_data: Any | None = None
property n_params: int

Number of parameters.

property parameters: ndarray

Homodyne-parity alias for posterior_mean.

property uncertainties: ndarray

Homodyne-parity alias for posterior_std.

property param_names: list[str]

Homodyne-parity alias for parameter_names.

get_param_summary(name)[source]

Get summary statistics for a parameter.

Parameters:

name (str) – Parameter name

Return type:

dict[str, float]

Returns:

Dict with mean, std, and credible interval bounds

get_samples(name)[source]

Get posterior samples for a parameter.

Parameters:

name (str) – Parameter name

Return type:

ndarray | None

Returns:

Array of samples or None if not stored

params_dict()[source]

Get posterior means as dictionary.

Return type:

dict[str, float]

validate_convergence(r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3)[source]

Validate convergence diagnostics.

Parameters:
  • r_hat_threshold (float) – Maximum acceptable R-hat

  • min_ess (int) – Minimum effective sample size

  • min_bfmi (float) – Minimum BFMI value

Return type:

list[str]

Returns:

List of warning messages
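The kind of threshold check validate_convergence performs might look like the sketch below. The function name and the message wording are invented for illustration, and checking only bulk ESS is an assumption. Note also that the method's default min_ess=100 differs from the CMCConfig.min_ess default of 400, so passing thresholds explicitly avoids ambiguity.

```python
import numpy as np


def convergence_warnings_sketch(names, r_hat, ess_bulk, bfmi, *,
                                r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3):
    """Illustrative threshold check returning a list of warning strings."""
    messages = []
    for name, rh, ess in zip(names, r_hat, ess_bulk):
        if rh > r_hat_threshold:
            messages.append(f"{name}: R-hat {rh:.3f} > {r_hat_threshold}")
        if ess < min_ess:
            messages.append(f"{name}: ESS {ess:.0f} < {min_ess}")
    for chain, value in enumerate(bfmi):
        if value < min_bfmi:
            messages.append(f"chain {chain}: BFMI {value:.2f} < {min_bfmi}")
    return messages


msgs = convergence_warnings_sketch(
    ["D0_ref", "v0"], np.array([1.01, 1.25]), np.array([850.0, 60.0]), [0.9, 0.2]
)
```

An empty list means every diagnostic passed the supplied thresholds.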

summary()[source]

Generate summary string.

Return type:

str

get_samples_array()[source]

Return samples as a 3-D array of shape (num_chains, num_samples, n_params).

Parameters with missing or None samples are filled with zeros. Flat 1-D sample arrays of shape num_chains * num_samples are reshaped to (num_chains, num_samples) automatically.

Return type:

ndarray
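The reshaping and zero-filling described for get_samples_array can be sketched as follows. The helper name is hypothetical, and the sketch assumes flat 1-D arrays are stored chain-major (all draws of chain 0 first), which the documentation does not state.

```python
import numpy as np


def samples_array_sketch(samples, names, num_chains, num_samples):
    """Illustrative (num_chains, num_samples, n_params) assembly."""
    out = np.zeros((num_chains, num_samples, len(names)))
    for j, name in enumerate(names):
        draws = samples.get(name)
        if draws is None:
            continue                       # missing parameter stays zero-filled
        # Handles both flat (chains * samples,) and (chains, samples) inputs.
        out[:, :, j] = np.asarray(draws).reshape(num_chains, num_samples)
    return out


demo = samples_array_sketch(
    {"D0_ref": np.arange(6.0), "v0": None}, ["D0_ref", "v0", "f0"], 2, 3
)
```

Here demo[..., 0] holds the reshaped D0_ref draws and the v0 and f0 slices remain zero.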

get_posterior_stats()[source]

Return per-parameter posterior statistics.

Returns a dict keyed by parameter name. Each value contains mean, std, median, hdi_5%, hdi_95%, r_hat, ess_bulk, ess_tail.

Parameters absent from self.samples are omitted.

Return type:

dict[str, dict[str, float]]

classmethod from_mcmc_samples(mcmc_samples, stats, analysis_mode='static', n_warmup=500, min_ess=None)[source]

Build a CMCResult from raw MCMC samples (homodyne parity).

Mirrors homodyne.optimization.cmc.results.CMCResult.from_mcmc_samples. Duck-typed: mcmc_samples must expose .samples (dict[str, ndarray]), .param_names (list[str]), .n_chains (int), .n_samples (int); stats must expose .num_divergent (int) and may expose .wall_time / .warmup_time.

Diagnostics (R-hat, ESS) are not computed here — they require per-chain reshaping and ArviZ. Callers that need diagnostics should run cmc_result_to_arviz() and overwrite .r_hat / .ess_* after construction.

Parameters:
  • mcmc_samples (Any) – Object holding posterior draws; see duck-typed surface above.

  • stats (Any) – Sampling statistics object; see duck-typed surface above.

  • analysis_mode (str) – Stored on the result for plot / report consumers. Default "static" mirrors homodyne.

  • n_warmup (int) – Number of warmup draws (recorded on the result).

  • min_ess (float | None) – Accepted for homodyne parity; ignored here because diagnostics are not computed by this factory.

Returns:

Populated result with parameter_names, posterior_mean, posterior_std, samples, and basic sampling/divergence metadata. credible_intervals is left empty; downstream consumers can populate it via get_posterior_stats.

Return type:

CMCResult

__init__(parameter_names, posterior_mean, posterior_std, credible_intervals, convergence_passed, r_hat=None, ess_bulk=None, ess_tail=None, bfmi=None, samples=None, map_estimate=None, num_warmup=0, num_samples=0, num_chains=0, num_shards=1, divergences=0, wall_time_seconds=None, metadata=<factory>, convergence_status='not_converged', warmup_time=None, per_angle_mode='auto', chi_squared=None, quality_flag=None, mean_contrast=None, std_contrast=None, mean_offset=None, std_offset=None, inference_data=None)
heterodyne.optimization.cmc.results.cmc_result_to_arviz(result)[source]

Convert a CMCResult to an ArviZ InferenceData object.

Samples stored in result.samples are reshaped to (num_chains, num_draws) when result.num_chains > 1 so that ArviZ can compute per-chain diagnostics (R-hat, ESS). When the result carries flat 1-D arrays the function treats the entire sequence as a single chain.

Parameters:

result (CMCResult) – Completed CMC analysis result.

Return type:

Any

Returns:

arviz.InferenceData with a posterior group populated from result.samples and, when available, sample_stats populated from result.bfmi.

Raises:
heterodyne.optimization.cmc.results.compare_cmc_nlsq(cmc_result, nlsq_result, consistency_sigma=2.0)[source]

Compare CMC posterior means with NLSQ point estimates.

Parameters that appear in both results are compared. Parameters present in only one result are silently skipped.

Parameters:
  • cmc_result (CMCResult) – Completed CMC result.

  • nlsq_result (Any) – Completed NLSQ result (NLSQResult instance).

  • consistency_sigma (float) – Number of posterior standard deviations within which the NLSQ estimate must fall to be flagged as consistent. Defaults to 2.0 (approximately 95 % credible interval).

Returns:

  • "common_parameters" — list of parameter names present in both.

  • "differences" — dict mapping name to (cmc_mean - nlsq_value).

  • "relative_deviations" — dict mapping name to abs(cmc_mean - nlsq_value) / cmc_std.

  • "consistent" — dict mapping name to bool (True if within consistency_sigma posterior std of the CMC mean).

  • "n_consistent" — int count of consistent parameters.

  • "n_inconsistent" — int count of inconsistent parameters.

  • "consistency_sigma" — the threshold used.

Return type:

dict[str, Any]
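
The returned keys can be illustrated with a small stand-alone sketch. It works on plain dicts rather than CMCResult / NLSQResult objects, assumes every posterior std is positive, and treats a deviation exactly equal to consistency_sigma as consistent; both the helper name compare_estimates and that boundary choice are assumptions.

```python
def compare_estimates(cmc_mean, cmc_std, nlsq_values, consistency_sigma=2.0):
    """Compare posterior means with point estimates for shared parameters."""
    # Parameters present in only one input are skipped.
    common = [name for name in cmc_mean if name in nlsq_values]
    differences = {n: cmc_mean[n] - nlsq_values[n] for n in common}
    relative = {n: abs(differences[n]) / cmc_std[n] for n in common}
    consistent = {n: relative[n] <= consistency_sigma for n in common}
    n_consistent = sum(consistent.values())
    return {
        "common_parameters": common,
        "differences": differences,
        "relative_deviations": relative,
        "consistent": consistent,
        "n_consistent": n_consistent,
        "n_inconsistent": len(common) - n_consistent,
        "consistency_sigma": consistency_sigma,
    }


# Illustrative values only; alpha_sample has no NLSQ counterpart here.
report = compare_estimates(
    cmc_mean={"D0_ref": 1.00, "v0": 5.0, "alpha_sample": -0.5},
    cmc_std={"D0_ref": 0.10, "v0": 0.5, "alpha_sample": 0.1},
    nlsq_values={"D0_ref": 1.05, "v0": 7.0},
)
```

Here v0 sits four posterior standard deviations from the NLSQ value and is flagged inconsistent, while D0_ref is within the threshold.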

heterodyne.optimization.cmc.results.merge_shard_cmc_results(shard_results, parameter_names=None)[source]

Combine multiple shard CMCResults into a single consensus result.

Uses inverse-variance weighting (precision weighting) to combine posterior means from independent shards, following the Consensus Monte Carlo methodology (Scott et al., 2016). Diagnostics (R-hat, ESS, BFMI) are set to their worst-case values across shards so that failures are never hidden by averaging.

Parameters:
  • shard_results (list[CMCResult]) – Non-empty list of per-shard CMCResults. Each must have the same parameter_names (or parameter_names override must be supplied).

  • parameter_names (list[str] | None) – Optional explicit parameter name list. When supplied, only these parameters are included in the merged result; they must be present in every shard.

Return type:

CMCResult

Returns:

A new CMCResult representing the consensus posterior.

Raises:

ValueError – If shard_results is empty or parameter names are inconsistent across shards when no override is given.
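
As a rough illustration of precision weighting for a single parameter, the sketch below combines per-shard means and keeps worst-case diagnostics. The helper names are hypothetical, and the combined std of 1/sqrt(total precision) is the usual Gaussian consensus formula, assumed here rather than taken from the source.

```python
def combine_shards(means, stds):
    """Inverse-variance weighted mean and combined std across shards."""
    precisions = [1.0 / (s * s) for s in stds]
    total_precision = sum(precisions)
    mean = sum(w * m for w, m in zip(precisions, means)) / total_precision
    std = (1.0 / total_precision) ** 0.5
    return mean, std


def worst_case(r_hats, ess_values):
    """Largest R-hat and smallest ESS, so a bad shard is never averaged away."""
    return max(r_hats), min(ess_values)


# Three shards; the second is less certain and therefore gets less weight.
mean, std = combine_shards(means=[1.0, 1.2, 0.9], stds=[0.1, 0.2, 0.1])
r_hat, ess = worst_case([1.01, 1.08, 1.02], [900, 450, 700])
```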

heterodyne.optimization.cmc.results.cmc_result_summary_table(result, ci_level='95', width=80)[source]

Format a CMCResult as a human-readable parameter summary table.

The table includes columns for parameter name, posterior mean, posterior standard deviation, credible interval bounds, R-hat, and bulk ESS. Missing diagnostics are shown as N/A.

Parameters:
  • result (CMCResult) – Completed CMC analysis result.

  • ci_level (str) – Credible interval level to display. Must be "95" or "89". Defaults to "95".

  • width (int) – Total character width of the horizontal rule separators. Defaults to 80.

Return type:

str

Returns:

Multi-line string containing the formatted table.

Raises:

ValueError – If ci_level is not "95" or "89".

NumPyro Model

NumPyro model definition for heterodyne Bayesian inference.

heterodyne.optimization.cmc.model.xpcs_model_heterodyne_scaled(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, num_shards=1)[source]

Joint multi-phi heterodyne CMC model (homodyne parity).

Mirrors homodyne.optimization.cmc.model.xpcs_model_scaled: ONE NUTS pass over pooled multi-phi data with shared 14 physics parameters and per-angle sampled contrast / offset. The likelihood site evaluates the Normal log-prob at every pooled point in a single numpyro.sample call, gather-by-phi-index.

Parameters:
  • data (Array) – Pooled C2 values, shape (n_total,) after diagonal filtering.

  • t (Array) – Unique time grid, shape (N,). Used by compute_c2_heterodyne.

  • q (float) – Wavevector magnitude (Å⁻¹).

  • dt (float) – Lag-time step (s).

  • phi_unique (Array) – Sorted unique phi angles, shape (n_phi,).

  • phi_indices (Array) – Per-point index into phi_unique, shape (n_total,).

  • i1_indices (Array) – Per-point indices into t for the first time coordinate, shape (n_total,). Pre-computed via np.searchsorted(t, t1).

  • i2_indices (Array) – Per-point indices into t for the second time coordinate, shape (n_total,). Pre-computed via np.searchsorted(t, t2).

  • noise_scale (float) – Data-driven sigma prior centre (homodyne-parity HalfNormal scale = noise_scale * 1.5 * sqrt(num_shards)).

  • space (ParameterSpace) – Parameter space holding priors and initial values for the 14 physics parameters + 2 scaling.

  • num_shards (int) – Shard count for CMC sigma-prior tempering (Scott et al. 2016). Default 1 (no tempering).

Return type:

None
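
The pooled index arrays can be prepared along these lines. This is a sketch using the standard-library bisect_left, which returns the same positions as np.searchsorted with its default side="left" on a sorted grid. The point layout and the reading of "diagonal filtering" as dropping t1 == t2 points are assumptions.

```python
from bisect import bisect_left

t = [0.0, 0.1, 0.2, 0.3]          # unique, sorted time grid
phi_unique = [0.0, 45.0, 90.0]    # sorted unique phi angles

# Hypothetical pooled points: (phi, t1, t2, c2 value).
points = [
    (0.0, 0.0, 0.1, 1.30),
    (0.0, 0.2, 0.2, 1.50),   # diagonal point (t1 == t2), dropped below
    (45.0, 0.1, 0.3, 1.12),
    (90.0, 0.3, 0.0, 1.05),
]

# Assumed meaning of diagonal filtering: remove t1 == t2 entries.
kept = [p for p in points if p[1] != p[2]]

phi_indices = [bisect_left(phi_unique, p[0]) for p in kept]
i1_indices = [bisect_left(t, p[1]) for p in kept]
i2_indices = [bisect_left(t, p[2]) for p in kept]
data = [p[3] for p in kept]
```

All four arrays end up with length n_total, so the likelihood can gather per-angle scaling and per-time quantities point by point.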

heterodyne.optimization.cmc.model.xpcs_model_heterodyne_constant(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, fixed_contrast, fixed_offset, num_shards=1)[source]

Joint multi-phi CMC model with FIXED per-angle scaling.

Mirrors homodyne.optimization.cmc.model.xpcs_model_constant. Per-angle contrast and offset are passed in as arrays (length n_phi, typically derived from quantile estimation on the raw data) and NOT sampled — only the 14 physics params + sigma are sampled.

Return type:

None

heterodyne.optimization.cmc.model.xpcs_model_heterodyne_averaged(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, num_shards=1)[source]

Joint multi-phi CMC model with SAMPLED averaged (single) scaling.

Mirrors homodyne.optimization.cmc.model.xpcs_model_averaged. A single contrast and a single offset are sampled and broadcast across all n_phi angles (cf. heterodyne per_angle_mode="auto" when the auto-resolver promotes to averaged scaling).

Return type:

None

heterodyne.optimization.cmc.model.xpcs_model_heterodyne_constant_averaged(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, fixed_contrast, fixed_offset, num_shards=1)[source]

Joint multi-phi CMC model with FIXED averaged (single) scaling.

Mirrors homodyne.optimization.cmc.model.xpcs_model_constant_averaged. A single contrast and offset (typically the mean of the NLSQ per-angle estimates) are broadcast across all n_phi angles. No scaling parameters are sampled — only the 14 physics params + sigma.

Return type:

None

heterodyne.optimization.cmc.model.get_heterodyne_pooled_model_for_mode(per_angle_mode, *, data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, fixed_contrast=None, fixed_offset=None, num_shards=1)[source]

Dispatch to the joint multi-phi model variant for per_angle_mode.

Mirrors homodyne.optimization.cmc.model.get_xpcs_model at the pooled-data layer. Returns a zero-arg callable suitable for passing to NUTS.

Modes:

  • "individual" / "scaled" → xpcs_model_heterodyne_scaled() (per-angle sampled contrast/offset).

  • "constant" → xpcs_model_heterodyne_constant(). Requires fixed_contrast and fixed_offset as length-n_phi arrays.

  • "auto" / "averaged" → xpcs_model_heterodyne_averaged() (single sampled averaged contrast/offset).

  • "constant_averaged" → xpcs_model_heterodyne_constant_averaged(). Requires scalar fixed_contrast and fixed_offset.

Return type:

Callable[[], None]

heterodyne.optimization.cmc.model.get_heterodyne_model(t, q, dt, phi_angle, c2_data, noise_scale, space, contrast=1.0, offset=1.0, shard_grid=None, priors_override=None, num_shards=1)[source]

Create NumPyro model for heterodyne correlation fitting.

Sigma is sampled as a posterior variable via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)), matching the homodyne parity convention so the posterior captures noise uncertainty.

Parameters:
  • t (Array) – Time array

  • q (float) – Wavevector

  • dt (float) – Time step

  • phi_angle (float) – Detector phi angle

  • c2_data (Array) – Observed correlation data — shape (N, N) for meshgrid path, or (n_pairs,) for element-wise path.

  • noise_scale (float) – Data-driven prior center for the measurement-uncertainty sigma posterior. Typically the mean / RMS of an external estimate from estimate_sigma().

  • space (ParameterSpace) – Parameter space with priors

  • contrast (float) – Speckle contrast (beta), default 1.0

  • offset (float) – Baseline offset, default 1.0

  • shard_grid (ShardGrid | None) – Optional pre-computed ShardGrid. When provided, uses the memory-efficient element-wise path (no N×N allocation). c2_data must then be flattened to match the shard grid’s paired indices.

  • priors_override (dict | None) – Optional dictionary mapping parameter names to NumPyro distributions. When provided, overrides the default space.priors[name] for any matching parameter name. Used by fit_cmc_sharded to inject tempered priors.

  • num_shards (int) – Number of CMC shards for sigma prior tempering. Widens the HalfNormal scale by sqrt(num_shards) so that the product across shards stays equivalent to the unsharded prior. Defaults to 1 (no tempering).

Returns:

NumPyro model function
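
The sqrt(num_shards) widening of the sigma prior can be checked numerically. The sketch below uses the unnormalised HalfNormal log-density and shows that the product of K widened priors has the same shape as one unsharded prior. Function names are hypothetical.

```python
import math


def sigma_prior_scale(noise_scale, num_shards=1):
    """HalfNormal scale used for the sampled sigma site."""
    return noise_scale * 1.5 * math.sqrt(num_shards)


def half_normal_logpdf_unnorm(x, scale):
    """HalfNormal log-density up to an additive constant (x >= 0)."""
    return -0.5 * (x / scale) ** 2


K = 4
base = sigma_prior_scale(0.02)          # unsharded scale
wide = sigma_prior_scale(0.02, K)       # per-shard scale

x = 0.05
product_of_shards = K * half_normal_logpdf_unnorm(x, wide)
unsharded = half_normal_logpdf_unnorm(x, base)
```

The two log-densities agree up to normalisation, which is why num_shards=1 reduces to no tempering.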

heterodyne.optimization.cmc.model.get_heterodyne_model_reparam(t, q, dt, phi_angle, c2_data, noise_scale, space, nlsq_params=None, reparam_config=None, scalings=None, contrast=1.0, offset=1.0, shard_grid=None, num_shards=1)[source]

Create NumPyro model with reparameterization for better sampling.

When NLSQ result and reparameterization config are provided, uses:

  1. Reference-time reparameterization for power-law pairs

  2. Smooth bounded transforms (tanh) instead of jnp.clip()

  3. NLSQ-informed priors with delta-method uncertainty propagation

Falls back to the original clip-based behavior when the new infrastructure is not provided (backward compatibility).

Sigma is sampled internally via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)) to match homodyne CMC parity.

Parameters:
  • t (Array) – Time array

  • q (float) – Wavevector

  • dt (float) – Time step

  • phi_angle (float) – Detector phi angle

  • c2_data (Array) – Observed correlation data

  • noise_scale (float) – Data-driven prior center for sampled sigma.

  • space (ParameterSpace) – Parameter space

  • nlsq_params (Array | None) – Optional NLSQ fitted values for centering (legacy path)

  • reparam_config (ReparamConfig | None) – Reparameterization config (enables new path)

  • scalings (dict[str, ParameterScaling] | None) – Pre-computed ParameterScaling per reparam-space param

  • num_shards (int) – CMC shard count for sigma prior tempering. Default 1.

Returns:

NumPyro model function

heterodyne.optimization.cmc.model.get_heterodyne_model_constant(t, q, dt, phi_angle, c2_data, noise_scale, space, fixed_contrast, fixed_offset, shard_grid=None, num_shards=1)[source]

Create NumPyro model with FIXED (pre-computed) per-angle scaling.

Contrast and offset are not sampled — they are provided as fixed arrays from a preceding NLSQ or preprocessing step. Suitable for per_angle_mode="constant", where each angle has its own fixed scaling but the physical parameters are shared.

Sigma is sampled internally via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)) for homodyne CMC parity.

Parameters:
  • t (Array) – Time array, shape (n_t,).

  • q (float) – Wavevector magnitude (Å⁻¹).

  • dt (float) – Lag-time step (s).

  • phi_angle (float) – Detector phi angle for this shard (degrees).

  • c2_data (Array) – Observed correlation data, shape (n_t,) or (n_phi, n_t).

  • noise_scale (float) – Data-driven prior center for sampled sigma.

  • space (ParameterSpace) – Parameter space carrying priors and fixed values.

  • fixed_contrast (Array) – Speckle contrast per angle, shape (n_phi,) or scalar.

  • fixed_offset (Array) – Baseline offset per angle, shape (n_phi,) or scalar.

  • num_shards (int) – CMC shard count for sigma prior tempering. Default 1.

Returns:

NumPyro model callable (no required arguments).

heterodyne.optimization.cmc.model.get_heterodyne_model_constant_averaged(t, q, dt, phi_angle, c2_data, noise_scale, space, mean_contrast, mean_offset, shard_grid=None, num_shards=1)[source]

Create NumPyro model with a single averaged scaling broadcast to all angles.

Both mean_contrast and mean_offset are scalars computed from the average over all phi angles. They are treated as fixed (not sampled) and broadcast uniformly. Suitable for per_angle_mode="constant_averaged".

Sigma is sampled internally via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)) for homodyne CMC parity.

Parameters:
  • t (Array) – Time array, shape (n_t,).

  • q (float) – Wavevector magnitude (Å⁻¹).

  • dt (float) – Lag-time step (s).

  • phi_angle (float) – Detector phi angle for this shard (degrees).

  • c2_data (Array) – Observed correlation data.

  • noise_scale (float) – Data-driven prior center for sampled sigma.

  • space (ParameterSpace) – Parameter space carrying priors and fixed values.

  • mean_contrast (float) – Scalar speckle contrast averaged over all phi angles.

  • mean_offset (float) – Scalar baseline offset averaged over all phi angles.

  • num_shards (int) – CMC shard count for sigma prior tempering. Default 1.

Returns:

NumPyro model callable (no required arguments).

heterodyne.optimization.cmc.model.get_heterodyne_model_individual(t, q, dt, phi_angles, c2_data, noise_scale, space, contrast_prior_loc=0.5, contrast_prior_scale=0.25, offset_prior_loc=1.0, offset_prior_scale=0.25, shard_grids=None, num_shards=1)[source]

Create NumPyro model with per-angle sampled contrast and offset.

The most general per-angle model: independently samples contrast_i and offset_i for each phi angle using weakly informative Gaussian priors. Suitable for per_angle_mode="individual".

Physical parameters are shared across all angles; the per-angle scaling lives in a numpyro.plate over the angle dimension.

Parameters:
  • t (Array) – Time array, shape (n_t,).

  • q (float) – Wavevector magnitude (Å⁻¹).

  • dt (float) – Lag-time step (s).

  • phi_angles (Array) – Detector phi angles, shape (n_phi,).

  • c2_data (Array) – Observed correlation data, shape (n_phi, n_t).

  • sigma – Measurement uncertainty — scalar or shape (n_phi, n_t).

  • space (ParameterSpace) – Parameter space carrying priors and fixed values.

  • contrast_prior_loc (Array | float) – Prior centre(s) for contrast. Scalar or (n_phi,) array. Default 0.5.

  • contrast_prior_scale (float) – Prior width for contrast. Default 0.25.

  • offset_prior_loc (Array | float) – Prior centre(s) for offset. Scalar or (n_phi,) array. Default 1.0.

  • offset_prior_scale (float) – Prior width for offset. Default 0.25.

  • shard_grids (list[ShardGrid] | None) – Optional list of pre-computed ShardGrids, one per phi angle. When provided, uses the memory-efficient element-wise path (no N×N allocation per angle). c2_data[ai] and sigma[ai] must then be flattened to match each shard grid’s paired indices. Without this, the model builds n_phi N×N matrices per NUTS step which can cause OOM for large datasets.

Returns:

NumPyro model callable (no required arguments).

heterodyne.optimization.cmc.model.get_model_for_mode(per_angle_mode, t, q, dt, phi_angle, c2_data, noise_scale, space, nlsq_result=None, reparam_config=None, num_shards=1, **kwargs)[source]

Select and build the appropriate NumPyro model based on per-angle mode.

Factory that maps per_angle_mode strings to concrete model constructors. Extra keyword arguments are forwarded to the selected constructor, allowing callers to pass mode-specific parameters (e.g. fixed_contrast, mean_contrast, phi_angles) without branching at the call site.

Mapping

"auto"

Delegates to get_heterodyne_model() (sampled contrast/offset from the parameter space) or get_heterodyne_model_reparam() when reparam_config is supplied.

"constant"

Delegates to get_heterodyne_model_constant(). Requires fixed_contrast and fixed_offset in kwargs.

"constant_averaged"

Delegates to get_heterodyne_model_constant_averaged(). Requires mean_contrast and mean_offset in kwargs.

"individual"

Delegates to get_heterodyne_model_individual(). Requires phi_angles and c2_data shaped (n_phi, n_t) in kwargs.

Parameters:
  • per_angle_mode (str) – One of "auto", "constant", "constant_averaged", "individual".

  • t (Array) – Time array.

  • q (float) – Wavevector magnitude (Å⁻¹).

  • dt (float) – Lag-time step (s).

  • phi_angle (float) – Scalar phi angle (used by non-individual modes).

  • c2_data (Array) – Observed correlation data.

  • noise_scale (float) – Data-driven prior centre for the sampled sigma site.

  • space (ParameterSpace) – Parameter space.

  • nlsq_result (NLSQResult | None) – Optional NLSQ result for warm-starting (used by "auto" mode when reparam_config is supplied).

  • reparam_config (ReparamConfig | None) – Optional reparameterization config. When provided alongside "auto" mode, activates the reparam model path.

  • num_shards (int) – CMC shard count for sigma prior tempering. Default 1.

  • **kwargs (object) – Mode-specific keyword arguments forwarded verbatim.

Return type:

Callable[[], None]

Returns:

NumPyro model callable (no required arguments).

Raises:

ValueError – If per_angle_mode is not a recognised string.

heterodyne.optimization.cmc.model.estimate_sigma(c2_data, method='diagonal', nlsq_result=None, n_bootstrap=200, bootstrap_seed=0)[source]

Estimate measurement uncertainty from data.

Supported methods:

  • "diagonal" – Uses the standard deviation of the diagonal of c2_data relative to its mean, floored at 1 % of the data’s overall scale. Fast and requires no additional information.

  • "constant" – Returns the overall standard deviation of c2_data as a scalar.

  • "local" – Computes a spatially smoothed local variance via scipy.ndimage.uniform_filter. Requires SciPy.

  • "residual" – Estimates sigma from the RMS of NLSQ residuals. Requires nlsq_result with a non-None residuals field. Falls back to "diagonal" if residuals are unavailable.

  • "bootstrap" – Draws n_bootstrap bootstrap replicates of the diagonal and returns the standard deviation of per-replicate means as the noise estimate. Useful when the diagonal has enough points to bootstrap.

Parameters:
  • c2_data (Array) – Correlation data, shape (n_t,) or (n_phi, n_t).

  • method (str) – Estimation method — one of "diagonal", "constant", "local", "residual", "bootstrap".

  • nlsq_result (NLSQResult | None) – NLSQ result object. Required (and used) only for method="residual".

  • n_bootstrap (int) – Number of bootstrap replicates for method="bootstrap". Default 200.

  • bootstrap_seed (int) – JAX PRNG seed for method="bootstrap". Default 0.

Return type:

Array

Returns:

Estimated sigma — same shape as c2_data for "local", scalar or (n_t,) array for all other methods.

Raises:

ValueError – If method is not a recognised string.
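
The "bootstrap" method can be sketched with the standard library alone. This is an illustration under stated assumptions, not the library's code: it uses Python's random module instead of a JAX PRNG, and the helper name bootstrap_sigma is hypothetical.

```python
import random
import statistics


def bootstrap_sigma(diagonal, n_bootstrap=200, seed=0):
    """Std of the means of bootstrap resamples of the diagonal."""
    rng = random.Random(seed)
    n = len(diagonal)
    replicate_means = []
    for _ in range(n_bootstrap):
        # Resample the diagonal with replacement, keeping its length.
        resample = [diagonal[rng.randrange(n)] for _ in range(n)]
        replicate_means.append(statistics.fmean(resample))
    return statistics.stdev(replicate_means)


# Illustrative diagonal values only.
diagonal = [1.48, 1.52, 1.50, 1.47, 1.53, 1.49, 1.51, 1.50]
sigma = bootstrap_sigma(diagonal)
```

Because it measures the spread of resampled means, the estimate is smaller than the raw standard deviation of the diagonal and shrinks as the diagonal gets longer.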

heterodyne.optimization.cmc.model.validate_model_output(c2_theory, params)[source]

Validate that theoretical C2 values are physically reasonable.

Checks for NaN/inf values and enforces the heterodyne C2 range constraint [-1.0, 10.0]. Heterodyne C2 can go negative due to the velocity phase term, unlike homodyne where C2 >= 0.

Parameters:
  • c2_theory (Array) – Theoretical C2 array from model evaluation.

  • params (Array) – Parameter array used to produce c2_theory (logged on failure for diagnostics).

Return type:

bool

Returns:

True if the output passes all checks, False otherwise.
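
A minimal sketch of the two checks, assuming a flat sequence of floats and inclusive range bounds. The helper name c2_is_physical is hypothetical, and the parameter logging on failure is omitted.

```python
import math

C2_MIN, C2_MAX = -1.0, 10.0   # heterodyne C2 range constraint


def c2_is_physical(c2_theory):
    """True when every value is finite and inside [C2_MIN, C2_MAX]."""
    return all(
        math.isfinite(v) and C2_MIN <= v <= C2_MAX
        for v in c2_theory
    )


ok = c2_is_physical([1.0, 0.5, -0.2])          # negative values are allowed
has_nan = c2_is_physical([1.0, float("nan")])  # non-finite values fail
too_large = c2_is_physical([11.0])             # out-of-range values fail
```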

heterodyne.optimization.cmc.model.get_model_param_count(n_phi, per_angle_mode='individual')[source]

Return total number of sampled parameters for the model.

Accounts for per-angle mode semantics when counting contrast/offset parameters that are sampled in addition to the shared physics parameters.

Per-angle mode contributions:

  • "constant" — 0 per-angle params (fixed contrast/offset).

  • "constant_averaged" — 0 per-angle params (fixed averaged contrast/offset).

  • "auto" — physics params only (contrast/offset live in the parameter space, already counted).

  • "individual" — 2 * n_phi per-angle params (contrast_z + offset_z per angle).

Parameters:
  • n_phi (int) – Number of scattering angles.

  • per_angle_mode (str) – One of "constant", "constant_averaged", "auto", "individual".

Return type:

int

Returns:

Total number of sampled parameters (int).

Raises:

ValueError – If per_angle_mode is not recognised.
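
The per-angle contribution alone can be written as a small lookup. The sketch below deliberately leaves out the shared physics parameters, whose count depends on the parameter space; the helper name per_angle_param_count is hypothetical.

```python
def per_angle_param_count(n_phi, per_angle_mode="individual"):
    """Number of sampled per-angle contrast/offset parameters."""
    extra = {
        "constant": 0,             # fixed per-angle scaling
        "constant_averaged": 0,    # fixed averaged scaling
        "auto": 0,                 # scaling already counted in the space
        "individual": 2 * n_phi,   # contrast_z + offset_z per angle
    }
    if per_angle_mode not in extra:
        raise ValueError(f"Unknown per_angle_mode: {per_angle_mode!r}")
    return extra[per_angle_mode]


n_individual = per_angle_param_count(3)
n_constant = per_angle_param_count(3, "constant")
```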

Priors

NLSQ-informed prior construction for heterodyne CMC analysis.

Builds NumPyro distribution dictionaries from NLSQ warm-start results or from the parameter registry defaults, including log-space priors for parameters flagged with log_space=True.

heterodyne.optimization.cmc.priors.build_nlsq_informed_priors(nlsq_result, param_space, width_factor=2.0)[source]

Build priors centered on NLSQ point estimates.

For each varying parameter, constructs a truncated Normal prior centered on the NLSQ best-fit value with width equal to the NLSQ uncertainty multiplied by width_factor. When NLSQ uncertainty is unavailable, falls back to registry prior_std or a fraction of the parameter range.

Parameters:
  • nlsq_result (NLSQResult) – Converged NLSQ result with parameter values and (optionally) uncertainties.

  • param_space (ParameterSpace) – Parameter space defining which parameters vary and their physical bounds.

  • width_factor (float) – Multiplier on NLSQ uncertainty to set prior width. Larger values give more diffuse priors. Default 2.0 gives a prior that spans roughly 4 sigma around the NLSQ estimate.

Return type:

dict[str, Distribution]

Returns:

Dictionary mapping parameter names to NumPyro distributions. Only includes parameters in param_space.varying_names.
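
The width rule and its fallbacks might look like the sketch below, which returns the arguments one would pass to a truncated Normal rather than building a NumPyro distribution. The helper name, the registry_std argument, and the 10 % range fraction are assumptions; the source does not state which fraction is used.

```python
def informed_prior_spec(value, uncertainty, bounds, width_factor=2.0,
                        registry_std=None, range_fraction=0.1):
    """Location, scale and bounds for a truncated Normal prior."""
    low, high = bounds
    if uncertainty is not None and uncertainty > 0:
        # Preferred: NLSQ uncertainty widened by width_factor.
        scale = uncertainty * width_factor
    elif registry_std is not None:
        # First fallback: registry prior_std.
        scale = registry_std
    else:
        # Last fallback: a fraction of the range (fraction is assumed).
        scale = range_fraction * (high - low)
    return {"loc": value, "scale": scale, "low": low, "high": high}


with_nlsq = informed_prior_spec(1.2, 0.05, (0.0, 10.0))
with_registry = informed_prior_spec(1.2, None, (0.0, 10.0), registry_std=0.3)
with_range = informed_prior_spec(1.2, None, (0.0, 10.0))
```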

heterodyne.optimization.cmc.priors.build_default_priors(param_space, registry=None, use_log_space_priors=True)[source]

Build default priors from the parameter registry.

Uses prior_mean and prior_std from each parameter’s ParameterInfo. All bounded parameters use TruncatedNormal so that temper_priors can scale them by sqrt(K) for Consensus Monte Carlo sharding.

When use_log_space_priors=True (default), parameters flagged log_space=True in the registry (D0_ref, D0_sample, v0) are overridden with LogNormal priors via build_log_space_priors(). LogNormal priors give much better mass-matrix conditioning for prefactors that span several orders of magnitude. Pass use_log_space_priors=False to fall back to TruncatedNormal for every parameter.

The reparameterized path (CMCConfig.use_reparam=True) is unaffected: it samples log_X_at_tref directly and never calls this function for those parameters.

Parameters:
  • param_space (ParameterSpace) – Parameter space defining which parameters vary and their physical bounds.

  • registry (ParameterRegistry | None) – Parameter registry to read metadata from. Defaults to DEFAULT_REGISTRY.

  • use_log_space_priors (bool) – Apply log-space priors to log_space=True parameters from the registry. Default True.

Return type:

dict[str, Distribution]

Returns:

Dictionary mapping parameter names to NumPyro distributions. Only includes parameters in param_space.varying_names.

heterodyne.optimization.cmc.priors.build_log_space_priors(param_names, registry=None)[source]

Build log-normal priors for parameters marked log_space=True.

For parameters where the registry’s log_space flag is set, this constructs a LogNormal distribution whose median matches the registry prior_mean (or the parameter default) and whose spread corresponds to the registry prior_std.

Parameters not flagged as log_space are silently skipped.

Parameters:
  • param_names (list[str]) – List of parameter names to consider.

  • registry (ParameterRegistry | None) – Parameter registry. Defaults to DEFAULT_REGISTRY.

Return type:

dict[str, Distribution]

Returns:

Dictionary mapping parameter names to LogNormal distributions. Only includes parameters where log_space=True.

heterodyne.optimization.cmc.priors.temper_priors(priors, num_shards)[source]

Scale prior widths for Consensus MC shard sub-posteriors.

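Each shard sees 1/K of the data, so each shard's prior is the full prior raised to the power 1/K. For Gaussian-type priors this is equivalent to multiplying the scale by sqrt(K), which keeps the product of the K sub-posteriors consistent with the unsharded posterior when they are combined in the consensus step.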

Supported distribution types and their tempering rules:

  • TruncatedNormal — scale multiplied by sqrt(K).

  • LogNormal — scale multiplied by sqrt(K).

  • Uniform — left unchanged (uninformative; no tempering needed).

  • All others — kept unchanged with a warning logged.

Parameters:
  • priors (dict[str, Distribution]) – Dict of NumPyro distributions, one per varying parameter.

  • num_shards (int) – Number of CMC shards (K). Must be >= 1.

Return type:

dict[str, Distribution]

Returns:

New dict with tempered distributions. Existing dict is not mutated.
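
The tempering rules can be sketched on a stand-in prior type. PriorSpec below is a hypothetical placeholder for a NumPyro distribution, used only so the example runs without NumPyro, and the warning for unsupported kinds is reduced to a comment.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class PriorSpec:
    """Hypothetical stand-in for a NumPyro distribution."""
    kind: str
    scale: float = 1.0


def temper(priors, num_shards):
    """Return a new dict with scales widened by sqrt(num_shards)."""
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    factor = math.sqrt(num_shards)
    tempered = {}
    for name, prior in priors.items():
        if prior.kind in ("TruncatedNormal", "LogNormal"):
            tempered[name] = replace(prior, scale=prior.scale * factor)
        else:
            # Uniform stays as-is; other kinds would also log a warning.
            tempered[name] = prior
    return tempered


priors = {
    "D0_ref": PriorSpec("LogNormal", 0.5),
    "alpha_sample": PriorSpec("TruncatedNormal", 0.2),
    "f0": PriorSpec("Uniform"),
}
out = temper(priors, num_shards=4)
```

With four shards every supported scale doubles, and the input dict is left untouched.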

heterodyne.optimization.cmc.priors.validate_priors(priors, param_space)[source]

Validate prior distributions against parameter space.

Checks:

  1. All varying parameters have a corresponding prior.

  2. Prior support overlaps with the parameter bounds (non-empty intersection).

  3. No degenerate (effectively zero-width) priors.

A prior is considered degenerate when its extractable scale is below 1e-12. Uniform priors are never degenerate.

Parameters:
  • priors (dict[str, Distribution]) – Dict of NumPyro distributions.

  • param_space (ParameterSpace) – Defines varying parameter names and their bounds.

Return type:

list[str]

Returns:

List of warning/error strings. Empty list means all checks passed.

heterodyne.optimization.cmc.priors.summarize_priors(priors)[source]

Format a human-readable summary of prior distributions.

For each prior, reports the distribution type and, where applicable, the mean, standard deviation, and support interval.

Parameters:

priors (dict[str, Distribution]) – Dict of NumPyro distributions.

Return type:

str

Returns:

Multi-line string with one row per parameter.

heterodyne.optimization.cmc.priors.get_param_names_in_order(vary_flags=None)[source]

Return the ordered list of parameter names that are set to vary.

Iteration order matches the registry’s insertion order, which follows the canonical group ordering defined in parameter_names.py (reference → sample → velocity → fraction → angle → scaling).

Parameters:

vary_flags (dict[str, bool] | None) – Optional override dict mapping parameter name to a bool indicating whether that parameter varies. Parameters absent from vary_flags fall back to the registry’s vary_default attribute. Pass None to use registry defaults for all parameters.

Return type:

list[str]

Returns:

List of parameter names for which the effective vary flag is True, in registry order.
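
The override-then-default resolution can be illustrated with a toy registry. The dict below is hypothetical: its parameter names appear elsewhere in this reference, but the vary defaults are made up for the example.

```python
# Hypothetical registry: name -> vary_default, in canonical group order.
REGISTRY_VARY_DEFAULT = {
    "D0_ref": True,
    "D0_sample": True,
    "alpha_sample": True,
    "v0": True,
    "f0": False,
}


def varying_names(vary_flags=None, registry=REGISTRY_VARY_DEFAULT):
    """Names whose effective vary flag is True, in registry order."""
    flags = vary_flags or {}
    return [
        name
        for name, default in registry.items()
        if flags.get(name, default)
    ]


defaults_only = varying_names()
overridden = varying_names({"f0": True, "v0": False})
```

The output order follows the registry, not the order of keys in vary_flags.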

heterodyne.optimization.cmc.priors.validate_initial_value_bounds(init_values, param_specs=None)[source]

Check that each initial value lies within the parameter’s bounds.

Parameters:
  • init_values (dict[str, float]) – Mapping of parameter name to proposed initial value.

  • param_specs (dict[str, Any] | None) – Optional dict of {name: {min_bound, max_bound}} overrides. When not provided, bounds are read from DEFAULT_REGISTRY.

Return type:

dict[str, list[str]]

Returns:

Mapping from parameter name to a list of warning strings. An empty dict indicates all values are within bounds.

heterodyne.optimization.cmc.priors.build_init_values_dict(nlsq_values=None, vary_flags=None, fallback='prior_mean')[source]

Build an initial-values dict for NUTS warm-starting.

For each varying parameter the value is resolved in order:

  1. NLSQ estimate from nlsq_values (if available).

  2. Registry prior_mean when fallback="prior_mean" and prior_mean is not None.

  3. Registry default value.

All resolved values are validated against bounds and clamped when necessary, with a logged warning per clamped parameter.

Parameters:
  • nlsq_values (dict[str, float] | None) – Optional NLSQ MAP estimates keyed by parameter name.

  • vary_flags (dict[str, bool] | None) – Optional dict controlling which parameters vary (see get_param_names_in_order()).

  • fallback (str) – Strategy for parameters absent from nlsq_values. "prior_mean" uses the registry prior mean (default); "default" uses the registry default value.

Return type:

dict[str, float]

Returns:

Dict mapping each varying parameter name to its initial value, ready to pass to run_with_init_values().
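
The resolution order and the final clamp can be sketched as follows. The registry contents are hypothetical, every listed parameter is treated as varying, and the per-parameter warning on clamping is omitted.

```python
# Hypothetical registry entries.
REGISTRY = {
    "D0_ref": {"default": 1.0, "prior_mean": 2.0, "bounds": (0.1, 100.0)},
    "alpha_sample": {"default": 0.0, "prior_mean": None, "bounds": (-2.0, 2.0)},
    "f0": {"default": 0.5, "prior_mean": 0.4, "bounds": (0.0, 1.0)},
}


def init_values(nlsq_values=None, fallback="prior_mean", registry=REGISTRY):
    """Resolve one initial value per parameter, then clamp to bounds."""
    nlsq_values = nlsq_values or {}
    resolved = {}
    for name, info in registry.items():
        if name in nlsq_values:
            value = nlsq_values[name]                 # 1. NLSQ estimate
        elif fallback == "prior_mean" and info["prior_mean"] is not None:
            value = info["prior_mean"]                # 2. registry prior mean
        else:
            value = info["default"]                   # 3. registry default
        low, high = info["bounds"]
        resolved[name] = min(max(value, low), high)   # clamp into bounds
    return resolved


init = init_values({"f0": 1.3})             # f0 is out of bounds
init_default = init_values(fallback="default")
```

In this example the out-of-bounds NLSQ value for f0 is pulled back to the upper bound, and alpha_sample falls through to its default because it has no prior mean.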

heterodyne.optimization.cmc.priors.extract_nlsq_values_for_cmc(nlsq_result)[source]

Extract parameter values and uncertainties from an NLSQ result.

Converts the array-based NLSQResult into plain float dictionaries suitable for CMC warm-starting. NaN and inf values are filtered out so that downstream prior construction never receives non-finite inputs.

Parameters:

nlsq_result (NLSQResult) – Converged NLSQ result with .parameters, .parameter_names, and optionally .uncertainties.

Returns:

  • values maps parameter name to its fitted float value (non-finite entries excluded).

  • uncertainties maps parameter name to its float uncertainty, or None when the NLSQ result carries no uncertainty information. Non-finite entries are excluded.

Return type:

tuple[dict[str, float], dict[str, float] | None]
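
A stand-alone sketch of the filtering step, taking the three fields as plain sequences instead of an NLSQResult. The helper name finite_dicts is hypothetical.

```python
import math


def finite_dicts(names, parameters, uncertainties=None):
    """Build (values, uncertainties) dicts, dropping non-finite entries."""
    values = {
        n: float(v) for n, v in zip(names, parameters) if math.isfinite(v)
    }
    if uncertainties is None:
        return values, None
    errors = {
        n: float(u) for n, u in zip(names, uncertainties) if math.isfinite(u)
    }
    return values, errors


names = ["D0_ref", "v0", "f0"]
values, errors = finite_dicts(
    names,
    parameters=[1.2, float("nan"), 0.3],
    uncertainties=[0.1, 0.2, float("inf")],
)
values_only, no_errors = finite_dicts(names, [1.2, 4.0, 0.3])
```

Values and uncertainties are filtered independently, so a parameter can appear in one dict and not the other.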

heterodyne.optimization.cmc.priors.estimate_per_angle_scaling(data_dict, angle_keys=None)[source]

Estimate contrast and offset scaling per scattering angle.

Uses simple heuristics on the raw g2 correlation data to provide starting-point estimates for the contrast and offset scaling parameters:

  • contrast estimate ≈ max(g2) - min(g2) over the full lag range.

  • offset estimate ≈ mean of the last 10 % of g2 values (long-lag baseline), clamped to [0, 1].

These are heuristics suitable for warm-starting, not MAP estimates. The NLSQ/MCMC optimisation will refine them.

Parameters:
  • data_dict (dict[str, Any]) –

    Dict mapping angle keys to g2 data. Each value may be:

    • a 1-D array-like (n_lags,) of g2 values, or

    • a dict with a "g2" key holding such an array.

  • angle_keys (list[str] | None) – Subset of keys in data_dict to process. Defaults to all keys when None.

Return type:

dict[str, tuple[float, float]]

Returns:

Mapping of angle_key -> (contrast_estimate, offset_estimate). Keys for which data could not be parsed are silently omitted.
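
The two heuristics and the two accepted input shapes can be sketched as below. Helper names and the sample g2 values are illustrative, and treating None or empty entries as unparseable is an assumption.

```python
def estimate_scaling(g2):
    """Heuristic (contrast, offset) from one g2 curve."""
    g2 = list(g2)
    contrast = max(g2) - min(g2)
    # Last 10 % of the lags, at least one point.
    tail = g2[-max(1, len(g2) // 10):]
    offset = sum(tail) / len(tail)
    return contrast, min(max(offset, 0.0), 1.0)   # offset clamped to [0, 1]


def estimate_all(data_dict, angle_keys=None):
    """Apply the heuristic per angle, skipping entries without usable data."""
    keys = list(data_dict) if angle_keys is None else angle_keys
    result = {}
    for key in keys:
        entry = data_dict.get(key)
        g2 = entry.get("g2") if isinstance(entry, dict) else entry
        if g2 is None or len(g2) == 0:
            continue
        result[key] = estimate_scaling(g2)
    return result


data = {
    "phi_0": [1.30, 1.25, 1.18, 1.10, 1.05, 1.02, 1.01, 1.00, 1.00, 0.99],
    "phi_45": {"g2": [1.2, 1.1, 1.0, 1.0]},
    "bad": None,
}
scaling = estimate_all(data)
```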

heterodyne.optimization.cmc.priors.estimate_contrast_offset_from_data(c2_data, t1, t2, contrast_bounds=(0.0, 1.0), offset_bounds=(0.5, 1.5), lag_floor_quantile=0.80, lag_ceiling_quantile=0.20, value_quantile_low=0.10, value_quantile_high=0.90)[source]

Estimate contrast and offset from C2 data via physics-informed quantile analysis.

Uses the correlation decay: C2 = contrast × g1² + offset. At large lags g1² → 0 so C2 → offset; at small lags g1² ≈ 1 so C2 ≈ contrast + offset.

Returns:

(contrast_est, offset_est) each clipped to their bounds.

Return type:

tuple[float, float]

heterodyne.optimization.cmc.priors.validate_init_values_order(init_values, expected_names)[source]

Validate that init-values key order matches the expected parameter order.

Homodyne CMC parity helper. Python 3.7+ dicts preserve insertion order, so positional consumers of init_values (e.g. when zipping with NLSQ arrays) depend on a stable iteration order. This check makes ordering bugs fail loudly with a descriptive error instead of producing silently wrong parameter bindings.

Parameters:
  • init_values (dict[str, float]) – Initial-values mapping, typically the output of build_init_values_dict().

  • expected_names (list[str]) – Required parameter order — usually get_param_names_in_order() for the active mode.

Raises:

ValueError – When the key count or per-position order disagrees with expected_names. The error names the first mismatching index and quotes the full lists for fast diagnosis.

Return type:

None
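A minimal sketch of this kind of order check is given here, assuming nothing beyond the behaviour described above. `check_order` and the exact error wording are hypothetical.

```python
def check_order(init_values, expected_names):
    """Hypothetical sketch: fail loudly when dict key order differs from expected."""
    actual = list(init_values)  # dicts iterate in insertion order
    if len(actual) != len(expected_names):
        raise ValueError(
            f"expected {len(expected_names)} keys, got {len(actual)}: "
            f"actual={actual}, expected={list(expected_names)}"
        )
    for i, (got, want) in enumerate(zip(actual, expected_names)):
        if got != want:
            raise ValueError(
                f"order mismatch at index {i}: {got!r} != {want!r}; "
                f"actual={actual}, expected={list(expected_names)}"
            )

check_order({"D0_ref": 1.0, "alpha_ref": -0.5}, ["D0_ref", "alpha_ref"])  # passes silently
```

Swapping the two keys in the dictionary would raise a ValueError naming index 0.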

heterodyne.optimization.cmc.priors.BOUNDARY_INTERIOR_MARGIN: float = 0.05

Default boundary-clamp margin: 5% of the bound range (linear) or 5% of the log-range (geometric, for log_space=True parameters). NUTS leapfrog step-size adaptation collapses if the chain initialises at a TruncatedNormal boundary; this margin keeps the warm-start away from the reflecting wall.

heterodyne.optimization.cmc.priors.clamp_params_to_interior(params, parameter_names, *, margin=BOUNDARY_INTERIOR_MARGIN)[source]

Shift parameter values inward from hard bounds (raw-array path).

Lower-level companion to clamp_to_interior() (the NLSQResult entry point). Named separately so the two public APIs don’t share a symbol — callers without an NLSQResult (e.g. direct Python users passing raw arrays to fit_cmc_jax()) use this entry.

Linear-scale parameters are clamped to [lo + margin * range, hi - margin * range] where range = hi - lo. Log-space parameters (D0_ref, D0_sample, v0 — registry log_space=True, positive min_bound) use a geometric margin so a 5% linear fraction of a multi-decade range does not produce a 500× clamp target.

Parameters:
  • params (ndarray) – Array of parameter values, shape (n,).

  • parameter_names (list[str]) – Names corresponding to each entry in params. Names not in the registry are passed through.

  • margin (float) – Fraction of bound range to keep clear of each wall. Default BOUNDARY_INTERIOR_MARGIN.

Return type:

tuple[ndarray, list[str]]

Returns:

(new_params, clamped_names) — a fresh array with values shifted inward and the list of names that were actually moved.
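The difference between the linear and the geometric margin can be sketched for a single scalar. `clamp_interior` and the bounds used below are hypothetical; the real function looks bounds up in the parameter registry and works on arrays.

```python
import math

def clamp_interior(value, lo, hi, margin=0.05, log_space=False):
    """Hypothetical sketch: shift one value inward from [lo, hi]."""
    use_log = log_space and lo > 0
    if use_log:
        # Geometric margin: work with log-bounds so the pad is a fraction of decades.
        lo_t, hi_t, v_t = math.log(lo), math.log(hi), math.log(max(value, lo))
    else:
        lo_t, hi_t, v_t = lo, hi, value
    pad = margin * (hi_t - lo_t)
    clamped = min(hi_t - pad, max(lo_t + pad, v_t))
    return math.exp(clamped) if use_log else clamped

linear = clamp_interior(-2.0, -2.0, 2.0)                   # value sitting on the lower wall
geometric = clamp_interior(1.0, 1.0, 1e6, log_space=True)  # six-decade range
```

With bounds [1, 1e6], the geometric margin moves a value at the lower wall to about 2, whereas a 5% linear margin over the same range would push it to roughly 5e4.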

heterodyne.optimization.cmc.priors.clamp_to_interior(result, fixed_param_overrides=None, *, margin=BOUNDARY_INTERIOR_MARGIN)[source]

Return a copy of result with parameters shifted inward from hard bounds.

NUTS step-size collapses when the chain initialises at a TruncatedNormal boundary. Linear-scale parameters are clamped to [min_bound + margin, max_bound - margin] where margin is 5% of the bound range. Log-space parameters (D0, v0) use a geometric margin so a linear fraction of a multi-decade range does not produce an absurd clamp target. Both ensure the leapfrog step-size adaptation starts well away from the reflecting wall.

fixed_param_overrides maps parameter names to values from the current model config’s fixed_parameters. When provided, any NLSQ result value for a fixed parameter is replaced with the config value before bounds clamping. This prevents a stale nlsq_data.npz (fitted in a prior run where the parameter was free) from propagating a superseded value into CMC initialisation, which can place the warm-start outside the reparameterised prior support and cause log-prior = −∞ → BFMI = 0 across all shards.

Parameters:
  • result (NLSQResult) – The NLSQ fit result to clamp.

  • fixed_param_overrides (dict[str, float] | None) – Optional map from parameter name to config-fixed value applied before clamping.

  • margin (float) – Fraction of bound range to keep clear of each wall. Default BOUNDARY_INTERIOR_MARGIN.

Return type:

NLSQResult

Returns:

A new NLSQResult with parameters clamped if any were out of the safe interior; the original result itself when no clamping was needed.

class heterodyne.optimization.cmc.priors.PriorBuilder[source]

Bases: object

Construct NumPyro priors from a parameter registry.

Parameters:
  • registry (ParameterRegistry | None) – Parameter registry (defaults to heterodyne.config.parameter_registry.DEFAULT_REGISTRY).

  • use_log_space_priors (bool) – When True (default), parameters flagged log_space=True in the registry are returned as LogNormal distributions instead of TruncatedNormal — mass-matrix conditioning for multi-decade prefactors (D0_ref, D0_sample, v0). See codex S1.

Raises:

RuntimeError – When the registry and parameter_space._DEFAULT_PRIOR_SPECS disagree on (prior_mean, prior_std) vs (loc, scale). This is Rule 9 from CLAUDE.md — the dual-prior system must stay in sync. Tolerance is rel_tol=1e-6, abs_tol=1e-9.

__init__(registry=None, use_log_space_priors=True)[source]
build(param_space)[source]

Build priors for the varying parameters in param_space.

Parameters:

param_space (ParameterSpace) – ParameterSpace defining which parameters vary and their physical bounds.

Return type:

dict[str, Distribution]

Returns:

Dictionary mapping each varying parameter name to its NumPyro distribution. D0_ref/D0_sample/v0 are LogNormal when use_log_space_priors=True, TruncatedNormal-family otherwise; remaining parameters are always TruncatedNormal.

heterodyne.optimization.cmc.priors.build_default_priors_via_builder(param_space, registry=None, use_log_space_priors=True)[source]

Functional shim around PriorBuilder for callers that want a one-shot factory. Constructing the builder runs the sync gate, so each call validates the dual-prior invariant.

Return type:

dict[str, Distribution]

heterodyne.optimization.cmc.priors.build_log_space_priors_via_builder(param_names, registry=None)[source]

Functional shim for the legacy log-space-only helper.

The construction-time gate runs here too — calling this entry point is enough to fail loud on a desynced registry/spec pair.

Return type:

dict[str, Distribution]

Sampler

High-level NUTS sampler wrapper for heterodyne CMC analysis.

Provides a SamplingPlan dataclass for sampling hyperparameters and a NUTSSampler class that wraps NumPyro’s MCMC with ergonomic factories, automatic chain initialization with perturbation, and ArviZ diagnostics.

class heterodyne.optimization.cmc.sampler.SamplingPlan[source]

Bases: object

Hyperparameters for NUTS sampling.

Immutable configuration that fully specifies a sampling run.

num_warmup

Number of warmup (adaptation) steps per chain.

num_samples

Number of posterior draws per chain after warmup.

num_chains

Number of independent MCMC chains.

target_accept

Target acceptance probability for dual-averaging step-size adaptation. Values in [0.6, 0.95] are typical.

max_tree_depth

Maximum binary tree depth for NUTS. Higher values allow longer trajectories but increase per-step cost.

adapt_step_size

Whether to use dual-averaging step-size adaptation during warmup.

dense_mass

Whether to estimate a dense (full) mass matrix during warmup, or use a diagonal approximation.

seed

Explicit random seed for reproducibility. If None, a cryptographically random seed is generated.

num_warmup: int = 500
num_samples: int = 1000
num_chains: int = 4
target_accept: float = 0.8
max_tree_depth: int = 10
adapt_step_size: bool = True
dense_mass: bool = True
chain_method: str = 'sequential'
seed: int | None = None
fast_warmup: bool = False

Rule 12 escape hatch propagated from CMCConfig.fast_warmup. When True, for_shard and AdaptiveSamplingPlan skip the dense-mass warmup floor. CI / pytest fast-mode only — not for production posteriors.

__post_init__()[source]

Validate hyperparameters.

Return type:

None

property effective_seed: int

Return the seed, generating one if not explicitly set.

classmethod from_config(config, n_data=None, n_params=None)[source]

Build a SamplingPlan from a CMCConfig.

Applies adaptive scaling when config.adaptive_sampling is True and n_data is provided: warmup and sample counts are scaled proportionally to the ratio n_data / _REFERENCE_SHARD_SIZE and clamped to the configured floors.

Parameters:
  • config (CMCConfig) – CMC configuration carrying all NUTS hyperparameters and adaptive-sampling knobs.

  • n_data (int | None) – Number of data points in this shard (or the full dataset when sharding is disabled). When None, no adaptive scaling is applied regardless of config.adaptive_sampling.

  • n_params (int | None) – Number of varying model parameters. Reserved for future dimension-aware scaling; currently unused.

Return type:

SamplingPlan

Returns:

Fully validated SamplingPlan.

for_shard(shard_size, full_size)[source]

Return a scaled-down plan appropriate for a single CMC shard.

Scales warmup and sample counts by sqrt(shard_size / full_size) to reflect the reduced information content of the shard. Each count is clamped to a minimum of max(1, num_x // 10), where num_x is the corresponding unscaled count (num_warmup or num_samples), to avoid degenerate one-step runs.

Parameters:
  • shard_size (int) – Number of data points in this shard.

  • full_size (int) – Total number of data points across all shards.

Return type:

SamplingPlan

Returns:

New SamplingPlan with adjusted warmup/sample counts and the same seed and other hyperparameters.

Raises:

ValueError – If shard_size <= 0 or full_size <= 0.
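The scaling and clamping arithmetic can be checked with a short sketch. `scale_count` is a hypothetical helper, truncation with `int()` is an assumption, and the dense-mass warmup floor that `fast_warmup` bypasses is not modelled here.

```python
import math

def scale_count(count, shard_size, full_size):
    """Hypothetical sketch of the sqrt scaling with a max(1, count // 10) floor."""
    if shard_size <= 0 or full_size <= 0:
        raise ValueError("shard_size and full_size must be positive")
    scale = math.sqrt(shard_size / full_size)
    floor = max(1, count // 10)
    return max(floor, int(count * scale))

warmup = scale_count(500, 2_500, 10_000)     # quarter-size shard, scale = 0.5
samples = scale_count(1000, 2_500, 10_000)
tiny = scale_count(500, 1, 1_000_000)        # scale = 0.001, so the floor applies
```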

__init__(num_warmup=500, num_samples=1000, num_chains=4, target_accept=0.8, max_tree_depth=10, adapt_step_size=True, dense_mass=True, chain_method='sequential', seed=None, fast_warmup=False)
class heterodyne.optimization.cmc.sampler.NUTSSampler[source]

Bases: object

High-level NUTS sampler wrapping NumPyro’s MCMC.

Manages kernel construction, chain initialization with perturbation, sampling execution, and ArviZ diagnostic extraction.

Use the from_plan() factory for the standard construction path.

__init__(mcmc, plan)[source]
property plan: SamplingPlan

The sampling plan used to configure this sampler.

classmethod from_plan(plan, model, init_strategy='init_to_median', chain_method=None)[source]

Create a NUTSSampler from a SamplingPlan and NumPyro model.

Parameters:
  • plan (SamplingPlan) – Sampling hyperparameters.

  • model (Callable[..., Any]) – NumPyro model function (callable with no required args).

  • init_strategy (str) – NumPyro initialization strategy name. One of "init_to_median", "init_to_sample", "init_to_value".

  • chain_method (str | None) – NumPyro chain execution method. "sequential" for single device, "parallel" for multi-device.

Return type:

NUTSSampler

Returns:

Configured NUTSSampler ready for run().

run(rng_key=None, init_params=None)[source]

Run MCMC sampling.

If init_params are provided, small random perturbations are added per chain to break symmetry and improve exploration.

Parameters:
  • rng_key (Array | None) – JAX PRNG key. If None, one is generated from the plan’s seed.

  • init_params (dict[str, Array] | None) – Optional initial values for each chain. Keys are parameter names; values should be scalars or have shape (num_chains,).

Return type:

dict[str, Any]

Returns:

Dictionary of posterior samples (ungrouped).

Raises:

RuntimeError – If sampling fails.

run_with_init_values(init_values, rng_key=None)[source]

Run MCMC seeded from NLSQ warm-start values.

Validates that the initial log density is finite before launching full sampling, raising early with a diagnostic message if not.

Parameters:
  • init_values (dict[str, float]) – NLSQ MAP estimates keyed by parameter name. Values should be in the same space as the NumPyro model samples (physics space, or reparameterized space if the model uses reparameterization).

  • rng_key (Array | None) – JAX PRNG key. Generated from the plan seed if None.

Return type:

dict[str, Any]

Returns:

Dictionary of posterior samples (ungrouped).

Raises:

RuntimeError – If the initial log density is not finite or if sampling itself fails.

get_divergence_stats()[source]

Extract divergence rate and tree-depth statistics from the last run.

Requires that run() or run_with_init_values() has been called.

Returns:

"divergence_rate"

Fraction of post-warmup transitions that were divergent. Zero when no divergences were recorded.

"mean_tree_depth"

Mean NUTS tree depth across all post-warmup samples and chains. Values near max_tree_depth indicate the trajectory is being truncated.

"max_tree_depth_fraction"

Fraction of samples that hit the maximum tree depth (plan.max_tree_depth).

Return type:

dict[str, float]

Raises:

RuntimeError – If called before run().

get_diagnostics()[source]

Extract ArviZ InferenceData for convergence diagnostics.

Return type:

InferenceData

Returns:

ArviZ InferenceData containing posterior samples, sample stats (energy, divergences), and warmup statistics.

Raises:

RuntimeError – If called before run().

log_adapter_diagnostics(run_logger=None)[source]

Log NUTS adapter state at INFO level — homodyne CMC parity helper.

Reports the adapted step_size per chain, a compact summary of the adapted inverse mass matrix, and per-step accept_prob / num_steps / potential_energy statistics when the corresponding extra fields are present.

Parameters:

run_logger (Any | None) – Logger to emit through. None uses this module’s logger.

Raises:

RuntimeError – If called before run().

Return type:

None

property mcmc: MCMC

Access the underlying NumPyro MCMC object.

class heterodyne.optimization.cmc.sampler.AdaptiveSamplingPlan[source]

Bases: object

Sampling plan that adjusts warmup/sample counts based on shard size.

Wraps a base SamplingPlan and scales it down by the square root of the size ratio when the shard is smaller than a reference size, while respecting floors derived from parameter count.

The scaling rule is:

scale = sqrt(shard_size / reference_shard_size)
num_warmup  = max(min_warmup_floor,  int(base.num_warmup  * scale))
num_samples = max(min_samples_floor, int(base.num_samples * scale))

where min_warmup_floor = max(50, 5 * n_params) and min_samples_floor = max(100, 10 * n_params).

base_plan

Base SamplingPlan for a full-size shard.

shard_size

Number of data points in this shard.

n_params

Number of varying model parameters. Used to set minimum sample-count floors.

base_plan: SamplingPlan
shard_size: int
n_params: int
get_plan()[source]

Return a SamplingPlan adjusted for this shard.

Scaling is sub-linear (square-root) so that small shards still receive enough samples to characterise the posterior, while large shards are not penalised by excessive warmup.

The floor on warmup is max(50, 5 * n_params) — enough adaptation steps to approximate the mass matrix for a 14-parameter model (floor = 70 steps). The floor on samples is max(100, 10 * n_params) — enough draws for basic ESS diagnostics.

Return type:

SamplingPlan

Returns:

Scaled SamplingPlan.
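A worked instance of the scaling rule and its floors, using the class defaults shown on this page (500 warmup steps, 1000 samples, reference shard size 10000), might look as follows. `adaptive_counts` is a hypothetical stand-alone function, not the method itself.

```python
import math

def adaptive_counts(base_warmup, base_samples, shard_size, n_params,
                    reference_shard_size=10_000):
    """Hypothetical sketch of the documented rule: sqrt scaling plus floors."""
    scale = math.sqrt(shard_size / reference_shard_size)
    min_warmup_floor = max(50, 5 * n_params)
    min_samples_floor = max(100, 10 * n_params)
    num_warmup = max(min_warmup_floor, int(base_warmup * scale))
    num_samples = max(min_samples_floor, int(base_samples * scale))
    return num_warmup, num_samples

quarter = adaptive_counts(500, 1000, 2_500, n_params=14)  # scale = 0.5
small = adaptive_counts(500, 1000, 100, n_params=14)      # scale = 0.1, floors win
```

For the 14-parameter case the small shard ends up at the floors (70 warmup steps, 140 samples) instead of the scaled values of 50 and 100.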

__init__(base_plan, shard_size, n_params, _reference_shard_size=10000)
heterodyne.optimization.cmc.sampler.DIVERGENCE_RATE_TARGET: float = 0.01

Target divergence rate — below this level the run is considered healthy.

heterodyne.optimization.cmc.sampler.DIVERGENCE_RATE_HIGH: float = 0.05

Elevated divergence rate — triggers a retry in run_nuts_with_retry().

heterodyne.optimization.cmc.sampler.DIVERGENCE_RATE_CRITICAL: float = 0.1

Critical divergence rate — posterior geometry is likely incompatible with HMC.
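One way to read the three thresholds together is as bands on the divergence rate. The sketch below assumes that reading; the band labels are hypothetical and are not the severity strings used by the library.

```python
DIVERGENCE_RATE_TARGET = 0.01
DIVERGENCE_RATE_HIGH = 0.05
DIVERGENCE_RATE_CRITICAL = 0.1

def classify_divergence_rate(rate):
    """Hypothetical sketch: map a divergence rate to an illustrative band label."""
    if rate > DIVERGENCE_RATE_CRITICAL:
        return "critical"   # geometry likely incompatible with HMC
    if rate > DIVERGENCE_RATE_HIGH:
        return "high"       # the level at which a retry is triggered
    if rate >= DIVERGENCE_RATE_TARGET:
        return "marginal"   # above target but below the retry level
    return "healthy"

labels = [classify_divergence_rate(r) for r in (0.0, 0.02, 0.07, 0.25)]
```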

class heterodyne.optimization.cmc.sampler.SamplingStats[source]

Bases: object

Summary statistics from a completed NUTS sampling run.

num_samples

Number of posterior draws per chain (post-warmup).

num_warmup

Number of warmup steps per chain.

num_divergences

Total divergent transitions across all chains.

divergence_rate

Fraction of post-warmup transitions that diverged.

mean_accept_prob

Mean Metropolis acceptance probability.

max_tree_depth_fraction

Fraction of samples that hit the maximum NUTS tree depth.

wall_time_seconds

Elapsed wall-clock time for the run.

num_samples: int
num_warmup: int
num_divergences: int
divergence_rate: float
mean_accept_prob: float
max_tree_depth_fraction: float
wall_time_seconds: float
property is_healthy: bool

Return True when divergence rate and acceptance probability are acceptable.

Criteria:

  • divergence_rate < 0.05 (below DIVERGENCE_RATE_HIGH)

  • mean_accept_prob > 0.6

__init__(num_samples, num_warmup, num_divergences, divergence_rate, mean_accept_prob, max_tree_depth_fraction, wall_time_seconds)
heterodyne.optimization.cmc.sampler.run_nuts_with_retry(sampler, model_fn, model_kwargs, max_retries=3, target_accept_increment=0.05, *, step_size_factor=None)[source]

Run NUTS sampling with automatic step-size reduction on high divergence.

Executes run() and checks the divergence rate after each attempt. When the rate exceeds DIVERGENCE_RATE_HIGH, a new NUTSSampler is built with target_accept RAISED by target_accept_increment (which drives dual averaging toward a SMALLER step size — the mathematically correct response to high divergence rate) and the run is retried. After max_retries attempts the result with the lowest divergence rate is returned regardless of health. Mirrors homodyne run_nuts_with_retry (sampler.py:1311-1314).

The model_fn is re-used across retries so it must be stateless (i.e. a pure NumPyro model function with no side effects).

Parameters:
  • sampler (NUTSSampler) – Configured NUTSSampler for the first attempt.

  • model_fn (Any) – NumPyro model callable. Not called directly here but passed to NUTSSampler.from_plan() for retry instances.

  • model_kwargs (dict[str, Any]) – Keyword arguments forwarded to the model via run(). Currently unused by run() (which takes rng_key and init_params); included for forward compatibility.

  • max_retries (int) – Maximum number of additional attempts after the first run. Total runs = max_retries + 1.

  • target_accept_increment (float) – Additive increase applied to target_accept each retry (e.g. 0.80 → 0.85 → 0.90 → 0.95). Raising the target acceptance LOWERS the dual-averaging step size, which is the mathematically correct response to high divergence rates. Must be in (0, 0.5). The target is clamped at 0.99 to avoid pathological tiny step sizes.

  • step_size_factor (float | None) – DEPRECATED keyword-only alias. Earlier versions multiplied target_accept by this factor (with factor < 1) on retry, which drove dual averaging toward a LARGER step size — the OPPOSITE of what high divergence rate calls for. When supplied, target_accept_increment is derived as (1 - step_size_factor) * 0.1 so legacy callers see a corrected mathematical direction.

Return type:

tuple[dict[str, Any], SamplingStats]

Returns:

Tuple of (samples_dict, SamplingStats) for the best attempt (lowest divergence rate).
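The target_accept values tried across retries, and the increment derived from the deprecated step_size_factor, can be sketched as below. Both helper names are hypothetical, the rounding only tidies floating-point noise, and the schedule shows what would be tried if every attempt exceeded DIVERGENCE_RATE_HIGH.

```python
def target_accept_schedule(start, max_retries=3, increment=0.05, cap=0.99):
    """Hypothetical sketch: target_accept for the first run and each retry."""
    schedule = [start]
    for _ in range(max_retries):
        # Raise the target each retry; dual averaging then picks a smaller step size.
        schedule.append(min(cap, round(schedule[-1] + increment, 10)))
    return schedule

def increment_from_legacy_factor(step_size_factor):
    """Documented mapping for the deprecated keyword: (1 - factor) * 0.1."""
    return (1.0 - step_size_factor) * 0.1

default_schedule = target_accept_schedule(0.80)
capped_schedule = target_accept_schedule(0.95, max_retries=2)
legacy_increment = increment_from_legacy_factor(0.5)
```

Starting from 0.80 the schedule is 0.80, 0.85, 0.90, 0.95; starting from 0.95 the cap of 0.99 is reached on the first retry.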

heterodyne.optimization.cmc.sampler.compute_mcmc_safe_initial_values(initial_values, *, q, dt, time_grid, target_g1=0.5, g1_threshold=0.1)[source]

Detect/repair vanishing-g1 initial parameters for heterodyne NUTS.

Inspects both the reference and sample transport components and, when either drives g1 below g1_threshold at a typical lag, rescales its (D0, D_offset) pair so that g1 ≈ target_g1. Returns None when no adjustment is needed so callers can short-circuit.

Parameters:
  • initial_values (dict[str, float] | None) – NLSQ warm-start dictionary. Expected keys include D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample, D_offset_sample.

  • q (float) – Wavevector magnitude (Å⁻¹).

  • dt (float) – Lag-time step (s).

  • time_grid (Any | None) – Time grid for integration.

  • target_g1 (float) – Target g1 value when rescaling.

  • g1_threshold (float) – Threshold below which rescaling fires.

Return type:

dict[str, float] | None

Returns:

A new dictionary with adjusted D0_* / D_offset_* values when adjustment is required, else None.

Diagnostics

Convergence diagnostics for CMC analysis.

class heterodyne.optimization.cmc.diagnostics.ConvergenceReport[source]

Bases: object

Report of convergence diagnostic checks.

passed: bool
r_hat_passed: bool
ess_passed: bool
bfmi_passed: bool
messages: list[str]
__init__(passed, r_hat_passed, ess_passed, bfmi_passed, messages)
heterodyne.optimization.cmc.diagnostics.validate_convergence(result, r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3)[source]

Validate MCMC convergence from CMC result.

Checks:

  1. R-hat (Gelman-Rubin statistic) < threshold for all parameters

  2. Effective sample size (ESS) > minimum for all parameters

  3. Bayesian Fraction of Missing Information (BFMI) > minimum

Parameters:
  • result (CMCResult) – CMC result with diagnostics

  • r_hat_threshold (float) – Maximum acceptable R-hat

  • min_ess (int) – Minimum acceptable ESS

  • min_bfmi (float) – Minimum acceptable BFMI

Return type:

ConvergenceReport

Returns:

ConvergenceReport
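The three checks combine into a single pass/fail flag. The sketch below mirrors the field names of ConvergenceReport, but `ReportSketch` and `check_convergence_sketch` are hypothetical, and the inputs are plain dictionaries instead of a CMCResult.

```python
from dataclasses import dataclass, field

@dataclass
class ReportSketch:
    passed: bool
    r_hat_passed: bool
    ess_passed: bool
    bfmi_passed: bool
    messages: list = field(default_factory=list)

def check_convergence_sketch(r_hat, ess, bfmi,
                             r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3):
    """Hypothetical sketch: apply the three documented checks."""
    bad_r = sorted(p for p, v in r_hat.items() if not v < r_hat_threshold)
    bad_e = sorted(p for p, v in ess.items() if not v > min_ess)
    bfmi_ok = bfmi > min_bfmi
    messages = []
    if bad_r:
        messages.append(f"R-hat not below {r_hat_threshold}: {', '.join(bad_r)}")
    if bad_e:
        messages.append(f"ESS not above {min_ess}: {', '.join(bad_e)}")
    if not bfmi_ok:
        messages.append(f"BFMI {bfmi:.2f} not above {min_bfmi}")
    r_ok, e_ok = not bad_r, not bad_e
    return ReportSketch(r_ok and e_ok and bfmi_ok, r_ok, e_ok, bfmi_ok, messages)

good = check_convergence_sketch({"D0_ref": 1.01}, {"D0_ref": 850.0}, bfmi=0.9)
bad = check_convergence_sketch({"D0_ref": 1.30}, {"D0_ref": 40.0}, bfmi=0.9)
```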

heterodyne.optimization.cmc.diagnostics.compute_posterior_contraction(result, prior_std)[source]

Compute Posterior Contraction Ratio for each parameter.

PCR = 1 - posterior_std / prior_std

Interpretation:

  • ~1.0 = strongly constrained by data

  • ~0.0 = poorly identified (prior dominates)

  • <0 = possible model misspecification (posterior wider than prior)

Parameters:
  • result (CMCResult) – CMC result with posterior_std.

  • prior_std (dict[str, float]) – Dict of prior standard deviations by parameter name.

Return type:

dict[str, float]

Returns:

Dict mapping parameter name to PCR value.
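The PCR formula is simple enough to check by hand. In this sketch `posterior_contraction` is a hypothetical stand-in that takes two plain dictionaries, and skipping parameters with a non-positive prior_std is an assumption.

```python
def posterior_contraction(posterior_std, prior_std):
    """Hypothetical sketch: PCR = 1 - posterior_std / prior_std per parameter."""
    return {
        name: 1.0 - posterior_std[name] / prior_std[name]
        for name in posterior_std
        if name in prior_std and prior_std[name] > 0
    }

pcr = posterior_contraction(
    {"D0_ref": 0.1, "alpha_ref": 0.9, "v0": 1.5},   # made-up posterior widths
    {"D0_ref": 1.0, "alpha_ref": 1.0, "v0": 1.0},   # made-up prior widths
)
```

Here D0_ref is well constrained (PCR 0.9), alpha_ref is prior-dominated (0.1), and v0 has a posterior wider than its prior (-0.5).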

heterodyne.optimization.cmc.diagnostics.compute_r_hat(samples)[source]

Compute rank-normalized R-hat from samples.

Delegates to arviz.rhat() which implements the recommended rank-normalized split-R-hat from Vehtari et al. (2021).

Parameters:

samples (ndarray) – Array of shape (n_chains, n_samples). Requires at least 2 chains; single-chain input returns NaN.

Return type:

float

Returns:

R-hat value (1.0 indicates convergence; >1.01 suggests issues)

heterodyne.optimization.cmc.diagnostics.compute_ess(samples)[source]

Compute effective sample size.

Delegates to arviz.ess() which uses FFT-based autocorrelation with Geyer’s initial monotone sequence estimator.

Parameters:

samples (ndarray) – 1D array of samples, or 2D array of shape (n_chains, n_draws)

Return type:

float

Returns:

Effective sample size (always >= 1.0)

heterodyne.optimization.cmc.diagnostics.compute_bfmi(energy)[source]

Compute Bayesian Fraction of Missing Information.

Delegates to arviz.bfmi(). Values < 0.3 indicate potential problems with HMC sampling.

Parameters:

energy (ndarray) – Array of HMC energies (1D or 2D)

Return type:

float

Returns:

BFMI value (1.0 for constant energy)

class heterodyne.optimization.cmc.diagnostics.DivergenceReport[source]

Bases: object

Divergence rate analysis.

divergence_rate: float
n_divergent: int
n_total: int
severity: str
messages: list[str]
__init__(divergence_rate, n_divergent, n_total, severity, messages)
heterodyne.optimization.cmc.diagnostics.analyze_divergences(samples)[source]

Analyse divergent transitions from MCMC samples.

Accepts either a raw samples dict that may contain a "diverging" field (shape (n_chains, n_draws)) or a CMCResult whose extra_fields attribute holds that field.

Divergence rate is computed globally (all chains combined) and per-chain for contextual messages.

Parameters:

samples (dict[str, ndarray] | CMCResult) – Samples dict with optional "diverging" boolean array, or a CMCResult.

Return type:

DivergenceReport

Returns:

DivergenceReport with severity classification and human-readable messages.

heterodyne.optimization.cmc.diagnostics.validate_convergence_sharded(results, r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3)[source]

Validate convergence across CMC shards.

Runs validate_convergence() on each shard and returns a combined ConvergenceReport that reflects the worst-case R-hat and the minimum ESS observed across all shards. A single failing shard causes the combined report to fail.

Parameters:
  • results (list[CMCResult]) – One CMCResult per shard.

  • r_hat_threshold (float) – Forwarded to validate_convergence().

  • min_ess (int) – Forwarded to validate_convergence().

  • min_bfmi (float) – Forwarded to validate_convergence().

Return type:

ConvergenceReport

Returns:

Combined ConvergenceReport with worst-case statistics.

heterodyne.optimization.cmc.diagnostics.compute_trace_diagnostics(samples, lags=(1, 5, 10))[source]

Compute trace-level diagnostics for a single parameter’s samples.

Parameters:
  • samples (ndarray) – Array of shape (n_chains, n_draws) or (n_draws,). When 1-D, treated as a single chain.

  • lags (tuple[int, ...]) – Autocorrelation lags to evaluate. Defaults to (1, 5, 10).

Returns:

autocorr

Dict mapping each lag to its mean autocorrelation across chains.

stationary

True if the absolute autocorrelation at lag 1 is below 0.5 for all chains (heuristic stationarity flag).

mixing_quality

One of "good", "moderate", or "poor" based on lag-1 autocorrelation magnitude.

n_chains

Number of chains.

n_draws

Number of draws per chain.

mean

Grand mean across all draws.

std

Grand standard deviation across all draws.

Return type:

dict[str, object]
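The autocorr and stationary entries can be illustrated with a pure-Python sketch. The estimator below is a common normalised form, and both function names are hypothetical; the library's exact estimator and its mixing_quality cut-offs are not shown on this page.

```python
def autocorr(chain, lag):
    """Hypothetical sketch: normalised lag-k autocorrelation of one chain."""
    n = len(chain)
    mean = sum(chain) / n
    var = sum((x - mean) ** 2 for x in chain)
    if var == 0 or lag >= n:
        return 0.0
    cov = sum((chain[i] - mean) * (chain[i + lag] - mean) for i in range(n - lag))
    return cov / var

def trace_sketch(chains, lags=(1, 5, 10)):
    """Mean autocorrelation per lag, plus the lag-1 heuristic flag."""
    mean_ac = {k: sum(autocorr(c, k) for c in chains) / len(chains) for k in lags}
    stationary = all(abs(autocorr(c, 1)) < 0.5 for c in chains)
    return {"autocorr": mean_ac, "stationary": stationary}

weak = [0.0, 1.0, 1.0, 0.0] * 25            # lag-1 autocorrelation close to zero
trend = [float(i) for i in range(100)]      # slow drift, lag-1 close to one
report_weak = trace_sketch([weak])
report_trend = trace_sketch([weak, trend])
```

Because the flag requires every chain to pass, adding the drifting chain turns it off.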

heterodyne.optimization.cmc.diagnostics.compute_pair_correlations(samples)[source]

Compute pairwise Pearson correlations between parameters.

Flattens all chains and draws for each parameter before computing correlations, so the result is chain-agnostic.

Useful for detecting parameter degeneracy: correlations with |r| > 0.9 indicate near-redundant parameters.

Parameters:

samples (dict[str, ndarray]) – Mapping from parameter name to array of shape (n_chains, n_draws) or (n_draws,).

Return type:

dict[str, dict[str, float]]

Returns:

Nested dict corr[param_a][param_b] = r where r is the Pearson correlation coefficient in [-1, 1]. The matrix is symmetric with ones on the diagonal. Returns an empty dict if fewer than two parameters are provided.
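A stdlib-only sketch of the flatten-then-correlate step is shown here. `pearson` and `pair_correlations` are hypothetical names, nested lists stand in for arrays, and returning NaN for a zero-variance parameter is an assumption.

```python
import math

def pearson(x, y):
    """Pearson correlation of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    if sxx == 0 or syy == 0:
        return float("nan")
    return sxy / math.sqrt(sxx * syy)

def pair_correlations(samples):
    """Hypothetical sketch: nested dict of correlations over flattened draws."""
    if len(samples) < 2:
        return {}
    flat = {}
    for name, arr in samples.items():
        nested = len(arr) > 0 and isinstance(arr[0], (list, tuple))
        flat[name] = [v for row in arr for v in row] if nested else list(arr)
    return {a: {b: 1.0 if a == b else pearson(flat[a], flat[b]) for b in flat}
            for a in flat}

corr = pair_correlations({
    "D0_ref": [[1.0, 2.0], [3.0, 4.0]],      # (n_chains, n_draws)
    "alpha_ref": [[2.0, 4.0], [6.0, 8.0]],
    "v0": [4.0, 3.0, 2.0, 1.0],              # already flat
})
```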

heterodyne.optimization.cmc.diagnostics.check_convergence(r_hat, ess_bulk, divergences, n_samples, n_chains, max_rhat=DEFAULT_MAX_RHAT, min_ess=DEFAULT_MIN_ESS, max_divergence_rate=DEFAULT_MAX_DIVERGENCE_RATE, num_shards=1)[source]

Check convergence criteria and return (status, warnings).

Returns:

(status, warnings) where status is "converged" | "divergences" | "not_converged". "divergences" takes priority over "not_converged" when both R-hat/ESS failures and excess divergences are present simultaneously.

Return type:

tuple[str, list[str]]
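The status priority can be sketched as follows. The threshold defaults used here are placeholders, since the values of DEFAULT_MAX_RHAT, DEFAULT_MIN_ESS and DEFAULT_MAX_DIVERGENCE_RATE are not shown on this page, and computing the rate over n_samples * n_chains * num_shards is an assumption.

```python
def convergence_status(r_hat, ess_bulk, divergences, n_samples, n_chains,
                       max_rhat=1.1, min_ess=100.0, max_divergence_rate=0.05,
                       num_shards=1):
    """Hypothetical sketch: 'divergences' outranks 'not_converged'."""
    warnings = []
    if max(r_hat.values()) > max_rhat:
        warnings.append("R-hat above threshold")
    if min(ess_bulk.values()) < min_ess:
        warnings.append("ESS below minimum")
    mixing_failed = bool(warnings)
    total = n_samples * n_chains * num_shards
    rate = divergences / total if total else 0.0
    too_divergent = rate > max_divergence_rate
    if too_divergent:
        warnings.append(f"divergence rate {rate:.1%} above limit")
    if too_divergent:
        return "divergences", warnings
    if mixing_failed:
        return "not_converged", warnings
    return "converged", warnings

ok = convergence_status({"a": 1.01}, {"a": 900.0}, 3, 1000, 4)
both = convergence_status({"a": 1.40}, {"a": 20.0}, 800, 1000, 4)
```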

heterodyne.optimization.cmc.diagnostics.create_diagnostics_dict(r_hat, ess_bulk, ess_tail, divergences, convergence_status, warnings, n_chains, n_warmup, n_samples, warmup_time, sampling_time, num_shards=1)[source]

Build a diagnostics dictionary suitable for JSON serialization.

Return type:

dict[str, Any]

heterodyne.optimization.cmc.diagnostics.summarize_diagnostics(r_hat, ess_bulk, divergences, n_samples, n_chains, num_shards=1)[source]

Create human-readable diagnostics summary (homodyne-parity).

Mirrors homodyne diagnostics.summarize_diagnostics. Used by CLI summary writers that emit a one-liner like Diagnostics: R-hat(max)=1.02, ESS(min)=420, divergences=3 (0.3%).

Parameters:
  • r_hat (dict[str, float]) – Per-parameter R-hat values. NaN values are skipped.

  • ess_bulk (dict[str, float]) – Per-parameter bulk ESS values. NaN values are skipped.

  • divergences (int) – Total divergence count across all chains and shards.

  • n_samples (int) – Posterior samples per chain.

  • n_chains (int) – Number of chains per shard.

  • num_shards (int) – Number of CMC shards (default 1 for non-sharded).

Returns:

Single-line diagnostics summary suitable for log emission.

Return type:

str
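A sketch that reproduces the one-liner quoted above is given here. The format string is inferred from that single example, and both the `summarize` name and the divergence-rate denominator (n_samples * n_chains * num_shards) are assumptions.

```python
import math

def summarize(r_hat, ess_bulk, divergences, n_samples, n_chains, num_shards=1):
    """Hypothetical sketch: single-line summary, skipping NaN entries."""
    rhats = [v for v in r_hat.values() if not math.isnan(v)]
    esses = [v for v in ess_bulk.values() if not math.isnan(v)]
    total = n_samples * n_chains * num_shards
    rate = 100.0 * divergences / total if total else 0.0
    rhat_txt = f"{max(rhats):.2f}" if rhats else "n/a"
    ess_txt = f"{min(esses):.0f}" if esses else "n/a"
    return (f"Diagnostics: R-hat(max)={rhat_txt}, ESS(min)={ess_txt}, "
            f"divergences={divergences} ({rate:.1f}%)")

line = summarize({"a": 1.02, "b": float("nan")}, {"a": 420.0, "b": 980.0},
                 divergences=3, n_samples=1000, n_chains=1)
```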

heterodyne.optimization.cmc.diagnostics.get_convergence_recommendations(max_rhat, min_ess, divergences, n_samples, n_chains, num_shards=1)[source]

Generate actionable recommendations for convergence issues.

Return type:

list[str]

heterodyne.optimization.cmc.diagnostics.log_analysis_summary(convergence_status, r_hat, ess_bulk, divergences, n_samples, n_chains, n_shards, shards_succeeded, execution_time)[source]

Log a formatted CMC analysis summary at INFO/ERROR level.

Return type:

None

class heterodyne.optimization.cmc.diagnostics.BimodalResult[source]

Bases: object

Result of a bimodality test for a single parameter’s samples.

param_name

Name of the parameter tested.

is_bimodal

True if the 2-component GMM is favoured by BIC.

bic_unimodal

BIC of the 1-component (unimodal) Gaussian mixture.

bic_bimodal

BIC of the 2-component Gaussian mixture.

delta_bic

bic_unimodal - bic_bimodal. Positive values favour the bimodal model.

means

Tuple of the two component means (None when not bimodal).

weights

Tuple of the two component mixing weights (None when not bimodal).

param_name: str
is_bimodal: bool
bic_unimodal: float
bic_bimodal: float
delta_bic: float
means: tuple[float, float] | None
weights: tuple[float, float] | None
__init__(param_name, is_bimodal, bic_unimodal, bic_bimodal, delta_bic, means, weights)
heterodyne.optimization.cmc.diagnostics.detect_bimodal(samples, param_name, bic_threshold=10.0, min_weight=0.0, min_separation=0.0)[source]

Fit 1- and 2-component Gaussian mixtures and compare BIC.

Uses scikit-learn’s GaussianMixture to estimate Bayesian Information Criterion for unimodal vs bimodal models. A positive delta_bic larger than bic_threshold is treated as evidence for bimodality. Two optional post-conditions can tighten the criterion:

  • min_weight: the minor mode must have at least this mixture weight (0–0.5). Filters spurious bimodality from very unequal modes.

  • min_separation: the modes must be separated by at least this many sample standard deviations. Filters bimodality in flat or nearly uniform posteriors.

Parameters:
  • samples (ndarray) – 1-D array of posterior draws for the parameter.

  • param_name (str) – Name used for logging and result labelling.

  • bic_threshold (float) – Minimum delta_bic (BIC_unimodal − BIC_bimodal) required to declare bimodality. Default 10.0 corresponds to strong evidence on the Raftery (1995) BIC scale.

  • min_weight (float) – Minimum weight of the minor mode for the result to be flagged as bimodal. Default 0.0 (no weight filter).

  • min_separation (float) – Minimum mode separation in sample standard deviations. Default 0.0 (no separation filter).

Return type:

BimodalResult

Returns:

BimodalResult with fitted statistics.

Raises:

ImportError – If scikit-learn is not installed, with a hint on how to add the optional dependency.
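The decision rule applied after the two mixtures are fitted can be sketched without scikit-learn. `is_bimodal` is a hypothetical helper that takes the BIC values, component means and weights as inputs; the numbers below are made up.

```python
def is_bimodal(bic_unimodal, bic_bimodal, means, weights, sample_std,
               bic_threshold=10.0, min_weight=0.0, min_separation=0.0):
    """Hypothetical sketch of the BIC test plus the two optional post-conditions."""
    delta_bic = bic_unimodal - bic_bimodal   # positive favours two components
    if delta_bic <= bic_threshold:
        return False
    if min(weights) < min_weight:            # minor mode too light
        return False
    separation = abs(means[0] - means[1]) / sample_std if sample_std > 0 else 0.0
    return separation >= min_separation      # modes far enough apart

clear = is_bimodal(1000.0, 950.0, (0.0, 5.0), (0.5, 0.5), sample_std=2.6)
weak_bic = is_bimodal(1000.0, 995.0, (0.0, 5.0), (0.5, 0.5), sample_std=2.6)
tiny_mode = is_bimodal(1000.0, 950.0, (0.0, 5.0), (0.95, 0.05), sample_std=2.6,
                       min_weight=0.1)
```

With the default min_weight and min_separation of 0.0 only the BIC comparison matters, which matches the "no filter" defaults described above.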

heterodyne.optimization.cmc.diagnostics.check_shard_bimodality(shard_samples, bic_threshold=10.0, min_weight=0.0, min_separation=0.0)[source]

Detect bimodality for each parameter across all CMC shards.

Runs detect_bimodal() for every (parameter, shard) combination and aggregates results per parameter.

Parameters:
  • shard_samples (dict[int, dict[str, ndarray]]) – Mapping of shard index to a dict of {param_name: samples_array}. Samples may be 1-D (n_draws,) or 2-D (n_chains, n_draws); they are flattened before testing.

  • bic_threshold (float) – Forwarded to detect_bimodal().

Return type:

dict[str, list[BimodalResult]]

Returns:

Mapping from parameter name to a list of BimodalResult, one entry per shard (in shard-index order).

heterodyne.optimization.cmc.diagnostics.compute_nlsq_comparison_metrics(posterior_samples, nlsq_values)[source]

Compare posterior statistics against NLSQ point estimates.

For each parameter present in both posterior_samples and nlsq_values, computes:

  • posterior_mean — mean of the flattened posterior draws.

  • posterior_std — standard deviation of the flattened draws.

  • nlsq_value — the NLSQ point estimate.

  • z_score — |nlsq_value - posterior_mean| / posterior_std. NaN when posterior_std == 0.

  • within_hdi — 1.0 if the NLSQ value falls inside the 95 % HDI, 0.0 otherwise.

Parameters:
  • posterior_samples (dict[str, ndarray]) – Mapping of parameter name to sample array of shape (n_chains, n_draws) or (n_draws,).

  • nlsq_values (dict[str, float]) – NLSQ point estimates keyed by parameter name.

Return type:

dict[str, dict[str, float]]

Returns:

Nested dict result[param_name][metric_name] = value. Only parameters present in both inputs are included.
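
As a rough illustration of the z_score metric and its zero-variance guard, the sketch below works on a flat list of draws using only the standard library. The function name z_score_vs_nlsq is hypothetical, and the use of the population standard deviation (the NumPy default) is an assumption, since the source does not state the degrees-of-freedom convention.

```python
import math
import statistics


def z_score_vs_nlsq(draws, nlsq_value):
    """|nlsq_value - posterior_mean| / posterior_std, NaN for zero spread."""
    posterior_mean = statistics.fmean(draws)
    # Assumption: population standard deviation (ddof = 0).
    posterior_std = statistics.pstdev(draws)
    if posterior_std == 0:
        return float("nan")
    return abs(nlsq_value - posterior_mean) / posterior_std
```

A z-score near zero means the NLSQ estimate sits at the posterior mean; values of a few units indicate that the two fits disagree relative to the posterior width.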

heterodyne.optimization.cmc.diagnostics.compute_precision_analysis(posterior_samples)[source]

Compute precision metrics for each parameter’s posterior.

For each parameter, calculates:

  • mean — posterior mean.

  • std — posterior standard deviation.

  • cv — coefficient of variation = std / |mean|. inf when mean == 0.

  • hdi_width — width of the shortest interval containing 95 % of the posterior draws (highest-density interval).

Parameters:

posterior_samples (dict[str, ndarray]) – Mapping of parameter name to sample array of shape (n_chains, n_draws) or (n_draws,).

Return type:

dict[str, dict[str, float]]

Returns:

Nested dict result[param_name][metric_name] = value.
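
The cv and hdi_width definitions above can be illustrated with a short standard-library sketch. The function name precision_metrics is hypothetical, the population standard deviation is an assumption, and the sliding-window search is one common way to find the shortest interval; the library may compute the HDI differently (for example through ArviZ).

```python
import math
import statistics


def precision_metrics(draws, prob=0.95):
    """Mean, std, coefficient of variation and shortest-interval width."""
    xs = sorted(draws)
    n = len(xs)
    mean = statistics.fmean(xs)
    std = statistics.pstdev(xs)  # assumption: ddof = 0
    cv = std / abs(mean) if mean != 0 else float("inf")
    # Slide a window holding ceil(prob * n) draws and keep the narrowest one.
    k = max(1, math.ceil(prob * n))
    hdi_width = min(xs[i + k - 1] - xs[i] for i in range(n - k + 1))
    return {"mean": mean, "std": std, "cv": cv, "hdi_width": hdi_width}
```

Because the window is chosen by width rather than by symmetric quantiles, a single far outlier does not inflate hdi_width the way it would inflate a min-to-max range.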

class heterodyne.optimization.cmc.diagnostics.ModeCluster[source]

Bases: object

A single mode from bimodal consensus combination.

mean

Per-parameter consensus mean for this mode.

std

Per-parameter consensus std.

weight

Fraction of shards supporting this mode (0-1).

n_shards

Number of shards in this cluster.

mean: dict[str, float]
std: dict[str, float]
weight: float
n_shards: int
__init__(mean, std, weight, n_shards)
class heterodyne.optimization.cmc.diagnostics.BimodalConsensusResult[source]

Bases: object

Result of mode-aware consensus combination.

modes

Mode clusters (typically 2) with per-mode consensus.

modal_params

Parameter names that triggered bimodal detection.

co_occurrence

Cross-parameter co-occurrence info.

modes: list[ModeCluster]
modal_params: list[str]
co_occurrence: dict[str, Any]
__init__(modes, modal_params, co_occurrence)
heterodyne.optimization.cmc.diagnostics.summarize_cross_shard_bimodality(bimodal_detections, n_shards, consensus_means=None, significance_threshold=0.05)[source]

Aggregate per-shard bimodal detections into a cross-shard summary.

Groups detections by parameter, computes mode statistics (mean of lower modes, mean of upper modes), and checks whether the consensus posterior mean falls between the modes (density trough).

Parameters:
  • bimodal_detections (dict[str, list[BimodalResult]]) – Mapping from parameter name to a list of BimodalResult (one per shard), as returned by check_shard_bimodality().

  • n_shards (int) – Total number of shards.

  • consensus_means (dict[str, float] | None) – Consensus posterior means for each parameter. Used to check if consensus falls in the density trough between modes.

  • significance_threshold (float) – Minimum separation significance (separation / pooled_std) for a bimodal split to be reported.

Returns:

  • "per_param": {param_name -> {fraction_bimodal, lower_mode_mean, upper_mode_mean, separation, significance, consensus_in_trough}}

  • "n_detections": Total bimodal detections across all params.

  • "n_shards": Number of shards.

Return type:

dict[str, Any]

heterodyne.optimization.cmc.diagnostics.cluster_shard_modes(bimodal_detections, shard_samples, param_bounds=None)[source]

Jointly cluster shards into two mode populations.

Uses the parameters that show bimodal behaviour to build a per-shard feature vector, then runs a simple 2-means clustering (no sklearn dependency) to partition shards.

Parameters:
  • bimodal_detections (dict[str, list[BimodalResult]]) – Mapping from parameter name to a list of BimodalResult as returned by check_shard_bimodality().

  • shard_samples (dict[int, dict[str, ndarray]]) – Per-shard samples mapping {shard_idx: {param_name: samples_array}}.

  • param_bounds (dict[str, tuple[float, float]] | None) – Optional per-parameter (lo, hi) bounds for normalization. If None, the global range across shards is used.

Return type:

tuple[list[int], list[int]]

Returns:

(cluster_0_indices, cluster_1_indices) where cluster 0 is the “lower” cluster (centroid with lower mean across features).
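
A dependency-free 2-means pass over per-shard feature vectors might look like the sketch below. It assumes the features have already been built and normalised; the name two_means, the initialisation at the extreme shards, and the iteration cap are illustrative assumptions, not details taken from the library.

```python
def two_means(features, n_iter=50):
    """features: {shard_idx: [f1, f2, ...]}, values already normalised."""
    ids = sorted(features)
    # Assumption: start the centroids at the shards with the smallest and
    # largest feature sum so the two clusters begin well separated.
    by_sum = sorted(ids, key=lambda i: sum(features[i]))
    centroids = [list(features[by_sum[0]]), list(features[by_sum[-1]])]
    labels = {}
    for _ in range(n_iter):
        new_labels = {}
        for i in ids:
            dist = [sum((a - b) ** 2 for a, b in zip(features[i], c))
                    for c in centroids]
            new_labels[i] = 0 if dist[0] <= dist[1] else 1
        if new_labels == labels:
            break  # assignments stopped changing
        labels = new_labels
        for k in (0, 1):
            members = [features[i] for i in ids if labels[i] == k]
            if members:
                centroids[k] = [sum(col) / len(members) for col in zip(*members)]
    # Cluster 0 is the one whose centroid has the lower mean across features.
    if sum(centroids[0]) > sum(centroids[1]):
        labels = {i: 1 - k for i, k in labels.items()}
    return ([i for i in ids if labels[i] == 0],
            [i for i in ids if labels[i] == 1])
```

Relabelling at the end keeps the output order stable, so "cluster 0" always refers to the lower mode regardless of how the iteration started.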

heterodyne.optimization.cmc.diagnostics.log_precision_analysis(analysis, log_fn=None, tolerance_pct=20.0)[source]

Format and emit the CMC vs NLSQ precision analysis report.

Homodyne CMC parity helper. Consumes the output of compute_precision_analysis() and produces a fixed-width report summarising z-scores, percent differences, uncertainty ratios, and posterior contraction ratios per parameter. Flags severe disagreements (z > 3, |diff| > tolerance_pct, ratio < 0.5x).

Parameters:
  • analysis (dict[str, dict[str, float]]) – Per-parameter precision metrics from compute_precision_analysis(). Each value is a dict with keys including cmc_mean, cmc_std, nlsq_value, z_score, relative_diff, uncertainty_ratio, posterior_contraction.

  • log_fn (Callable[[str], None] | None) – Logger function to emit the report through. None routes through this module’s logger at INFO level.

  • tolerance_pct (float) – Percent-difference threshold for the [WARN] marker. Default 20.0.

Return type:

str

Returns:

The formatted report as a multi-line string.

Reparameterization

Reference-time reparameterization for heterodyne CMC.

Breaks banana-shaped posteriors for correlated power-law pairs (D0/alpha) by sampling at a reference time t_ref where the product D(t_ref) = D0 * t_ref^alpha is well-constrained by data.

Adapted from homodyne/optimization/cmc/reparameterization.py for heterodyne’s 3 power-law pairs – D0_ref/alpha_ref, D0_sample/alpha_sample, and v0/beta.

class heterodyne.optimization.cmc.reparameterization.ReparamConfig[source]

Bases: object

Configuration for reference-time reparameterization.

enable_d_ref

Reparameterize D0_ref/alpha_ref pair.

enable_d_sample

Reparameterize D0_sample/alpha_sample pair.

enable_v_ref

Reparameterize v0/beta pair.

t_ref

Reference time (geometric mean of dt and t_max).

enable_d_ref: bool = True
enable_d_sample: bool = True
enable_v_ref: bool = True
t_ref: float = 1.0
property enabled_pairs: list[tuple[str, str]]

Return list of enabled (prefactor, exponent) pairs.

is_reparameterized(name)[source]

Check if a parameter participates in reparameterization.

Return type:

bool

get_reparam_name(prefactor)[source]

Get the reparameterized log-space name for a prefactor.

Return type:

str

__init__(enable_d_ref=True, enable_d_sample=True, enable_v_ref=True, t_ref=1.0)
heterodyne.optimization.cmc.reparameterization.compute_t_ref(dt, t_max, fallback_value=None)[source]

Compute reference time as geometric mean of dt and t_max.

t_ref = sqrt(dt * t_max)

This places t_ref in the middle of the logarithmic time range, where the correlation function is most sensitive to the transport parameters.

Parameters:
  • dt (float) – Time step (minimum lag time).

  • t_max (float) – Maximum lag time.

  • fallback_value (float | None) – Value to use if dt or t_max are invalid.

Return type:

float

Returns:

Reference time.
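
The geometric-mean rule is easy to check numerically. The sketch below is a stand-alone illustration with a hypothetical name, t_ref_geometric_mean; raising ValueError when the inputs are invalid and no fallback is supplied is an assumption, since the source does not say what the library does in that case.

```python
import math


def t_ref_geometric_mean(dt, t_max, fallback_value=None):
    """sqrt(dt * t_max), i.e. the midpoint of the range on a log axis."""
    valid = all(math.isfinite(x) and x > 0 for x in (dt, t_max))
    if valid:
        return math.sqrt(dt * t_max)
    if fallback_value is not None:
        return fallback_value
    # Assumption: the library's behaviour without a fallback is not documented.
    raise ValueError("dt and t_max must be positive and finite")
```

For dt = 0.1 and t_max = 1000 the result is 10, which lies two decades above dt and two decades below t_max.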

heterodyne.optimization.cmc.reparameterization.transform_nlsq_to_reparam_space(nlsq_values, nlsq_uncertainties, t_ref, config=None)[source]

Transform NLSQ point estimates to reparameterized space.

For each enabled power-law pair (A0, alpha):

log_A_at_tref = log(A0) + alpha * log(t_ref)

Delta-method propagation for uncertainty:

Var(log_A_at_tref) ≈ (sigma_A0/A0)² + (log(t_ref) * sigma_alpha)²

Non-reparameterized parameters pass through unchanged.

Parameters:
  • nlsq_values (dict[str, float]) – NLSQ best-fit values by parameter name.

  • nlsq_uncertainties (dict[str, float]) – NLSQ uncertainties by parameter name.

  • t_ref (float) – Reference time.

  • config (ReparamConfig | None) – Reparameterization config. Defaults to all pairs enabled.

Return type:

tuple[dict[str, float], dict[str, float]]

Returns:

Tuple of (transformed_values, transformed_uncertainties).
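
For a single power-law pair, the forward transform and its delta-method uncertainty can be written out directly. This is a minimal sketch with a hypothetical function name, to_reparam; like the formula above, it ignores any covariance between A0 and alpha.

```python
import math


def to_reparam(a0, alpha, sigma_a0, sigma_alpha, t_ref):
    """Return (log_A_at_tref, sigma_log_A_at_tref) for one (A0, alpha) pair."""
    log_t = math.log(t_ref)
    log_a_at_tref = math.log(a0) + alpha * log_t
    # Delta method: d/dA0 log(A0) = 1/A0 and d/dalpha (alpha*log t) = log t.
    variance = (sigma_a0 / a0) ** 2 + (log_t * sigma_alpha) ** 2
    return log_a_at_tref, math.sqrt(variance)
```

At t_ref = 1 the exponent term vanishes (log 1 = 0), so the transformed uncertainty reduces to the relative uncertainty of the prefactor.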

heterodyne.optimization.cmc.reparameterization.transform_to_sampling_space(params, config)[source]

Transform physics-space parameters to sampling (reparam) space.

Used for initializing MCMC chains from NLSQ results.

Parameters:
  • params (dict[str, float]) – Physics-space parameter values.

  • config (ReparamConfig) – Reparameterization config.

Return type:

dict[str, float]

Returns:

Sampling-space parameter values.

heterodyne.optimization.cmc.reparameterization.reparam_to_physics_jax(log_at_tref, alpha, t_ref)[source]

Back-transform reparameterized values to physics space (JAX).

A0 = exp(log_at_tref - alpha * log(t_ref))

Parameters:
  • log_at_tref (Array) – Log of the quantity at t_ref.

  • alpha (Array) – Power-law exponent.

  • t_ref (float) – Reference time.

Return type:

Array

Returns:

A0 (prefactor in physics space).

heterodyne.optimization.cmc.reparameterization.transform_to_physics_space(samples, config)[source]

Transform sampling-space posterior samples to physics space.

Vectorized over sample dimension. For each enabled pair, computes:

A0 = exp(log_at_tref - alpha * log(t_ref))

Non-reparameterized parameters pass through.

Parameters:
  • samples (dict[str, ndarray]) – Dict of posterior samples keyed by sampling-space names.

  • config (ReparamConfig) – Reparameterization config.

Return type:

dict[str, ndarray]

Returns:

Dict of physics-space samples.
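
The back-transform applies the same formula draw by draw. The sketch below uses plain lists and a hypothetical name, to_physics; the library operates on arrays and dicts keyed by sampling-space names, which are not reproduced here.

```python
import math


def to_physics(log_at_tref_draws, alpha_draws, t_ref):
    """A0 = exp(log_at_tref - alpha * log(t_ref)) for every posterior draw."""
    log_t = math.log(t_ref)
    return [math.exp(log_a - alpha * log_t)
            for log_a, alpha in zip(log_at_tref_draws, alpha_draws)]
```

Because the exponent is sampled unchanged, only the prefactor needs to be reconstructed; applying the forward transform and then this function returns the original A0 up to floating-point error.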

heterodyne.optimization.cmc.reparameterization.D_OFFSET_RATIO_MIN: float = -0.99

Minimum allowed ratio. Slightly above -1 so the implied D_offset > -D_ref keeps the diffusion rate strictly positive at t_ref while preserving gradient information near the boundary.

heterodyne.optimization.cmc.reparameterization.d_offset_to_ratio(d_offset, d_ref)[source]

Convert an absolute offset to the ratio representation.

d_offset_ratio = d_offset / d_ref. Returns 0.0 when d_ref is non-positive (degenerate channel).

Return type:

float

heterodyne.optimization.cmc.reparameterization.ratio_to_d_offset(ratio, d_ref)[source]

Reconstruct the absolute offset from the ratio representation.

d_offset = ratio * d_ref. Returns 0.0 when d_ref is non-positive.

Return type:

float

heterodyne.optimization.cmc.reparameterization.heterodyne_offset_ratios_from_physics(params, t_ref)[source]

Compute D_offset_ratio for both reference and sample channels.

Evaluates each channel’s diffusion magnitude at t_ref as D_ref(t_ref) = D0 * t_ref**alpha and returns the ratios D_offset_*_ratio = D_offset_* / D_ref(t_ref). Channels whose D_ref(t_ref) is non-positive yield a 0.0 ratio so callers can fall back to direct sampling for that channel.

Parameters:
  • params (dict[str, float]) – Mapping containing D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample, D_offset_sample.

  • t_ref (float) – Reference time at which to evaluate D_ref.

Return type:

dict[str, float]

Returns:

{"D_offset_ref_ratio": float, "D_offset_sample_ratio": float}.

heterodyne.optimization.cmc.reparameterization.heterodyne_physics_offsets_from_ratios(ratios, physics, t_ref)[source]

Inverse of heterodyne_offset_ratios_from_physics().

Given D_offset_*_ratio values and the current physics-space (D0_*, alpha_*) parameters, returns the absolute D_offset_* values consistent with D_ref(t_ref).

Return type:

dict[str, float]
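
The ratio representation and its inverse can be sketched for both channels as follows. The function names and the output keys of the inverse are assumptions chosen for illustration; only the input keys and the 0.0 fallback for a non-positive D_ref(t_ref) come from the descriptions above.

```python
def d_at_tref(d0, alpha, t_ref):
    """Diffusion magnitude of one channel evaluated at the reference time."""
    return d0 * t_ref ** alpha


def offset_ratios_from_physics(params, t_ref):
    ratios = {}
    for ch in ("ref", "sample"):
        d = d_at_tref(params[f"D0_{ch}"], params[f"alpha_{ch}"], t_ref)
        # Degenerate channel: report 0.0 so the caller can sample directly.
        ratios[f"D_offset_{ch}_ratio"] = params[f"D_offset_{ch}"] / d if d > 0 else 0.0
    return ratios


def physics_offsets_from_ratios(ratios, physics, t_ref):
    offsets = {}
    for ch in ("ref", "sample"):
        d = d_at_tref(physics[f"D0_{ch}"], physics[f"alpha_{ch}"], t_ref)
        offsets[f"D_offset_{ch}"] = ratios[f"D_offset_{ch}_ratio"] * d if d > 0 else 0.0
    return offsets
```

Note that the round trip is lossy for a degenerate channel: once the ratio has been set to 0.0, the original offset cannot be recovered from it.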

Scaling

Parameter scaling utilities for CMC warm-start initialization and contrast/offset estimation.

Smooth bounded parameter scaling for heterodyne CMC.

Replaces jnp.clip() (zero gradient at bounds) with tanh-based smooth bounding that is differentiable everywhere, allowing NUTS to adapt its mass matrix near parameter boundaries.

Adapted from homodyne/optimization/cmc/scaling.py.

class heterodyne.optimization.cmc.scaling.ParameterScaling[source]

Bases: object

Scaling specification for a single parameter.

Defines the mapping between z-space (standard normal for MCMC) and original physics space with smooth bounding.

name

Parameter name.

center

NLSQ best-fit value (center of prior).

scale

Prior width (NLSQ uncertainty × width_factor).

low

Lower bound in physics space.

high

Upper bound in physics space.

name: str
center: float
scale: float
low: float
high: float
to_normalized(value)[source]

Transform from physics space to z-space (normalized).

z = (value - center) / scale

Parameters:

value (float | Array) – Physics-space value.

Return type:

float | Array

Returns:

Normalized z-space value.

to_original(z_value)[source]

Transform from z-space to bounded original (physics) space.

raw = center + scale * z

result = smooth_bound(raw, low, high)

Parameters:

z_value (Array) – Normalized z-space value.

Return type:

Array

Returns:

Bounded physics-space value.

__init__(name, center, scale, low, high)
heterodyne.optimization.cmc.scaling.smooth_bound(raw, low, high)[source]

Smooth bounding using tanh transform.

Maps (-inf, +inf) → (low, high) via:

mid + half * tanh((raw - mid) / half)

where mid = (low + high) / 2 and half = (high - low) / 2.

This is differentiable everywhere, unlike jnp.clip(), whose gradient is zero once a value reaches the bounds, which stalls NUTS adaptation.

Parameters:
  • raw (Array) – Unbounded input value.

  • low (float) – Lower bound.

  • high (float) – Upper bound.

Return type:

Array

Returns:

Bounded value in (low, high).
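
The difference between tanh bounding and hard clipping shows up in the slope beyond the bounds. The scalar sketch below uses the standard library instead of JAX and takes mid and half as the midpoint and half-width of the interval, which is the only choice consistent with the stated (low, high) range; clip and slope are hypothetical helpers.

```python
import math


def smooth_bound(raw, low, high):
    mid = 0.5 * (low + high)
    half = 0.5 * (high - low)
    return mid + half * math.tanh((raw - mid) / half)


def clip(raw, low, high):
    return min(max(raw, low), high)


def slope(fn, x, low, high, h=1e-6):
    """Central finite-difference derivative of fn with respect to x."""
    return (fn(x + h, low, high) - fn(x - h, low, high)) / (2.0 * h)
```

Evaluated at raw = 1.2 with bounds (0, 1), the clipped function has slope exactly zero, while the tanh version keeps a positive slope that a gradient-based sampler can still follow back into the interval.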

heterodyne.optimization.cmc.scaling.smooth_bound_inverse(value, low, high)[source]

Inverse of smooth_bound for initialization.

Recovers the raw (unbounded) value from a bounded value:

raw = mid + half * arctanh((value - mid) / half)

Parameters:
  • value (float) – Bounded value in (low, high).

  • low (float) – Lower bound.

  • high (float) – Upper bound.

Return type:

float

Returns:

Unbounded raw value.
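
A scalar round trip makes the inverse relationship concrete. This sketch assumes mid = (low + high) / 2 and half = (high - low) / 2 and uses math.atanh in place of the array version; how the library treats a value sitting exactly on a bound is not stated, and here math.atanh raises ValueError in that case.

```python
import math


def smooth_bound(raw, low, high):
    mid, half = 0.5 * (low + high), 0.5 * (high - low)
    return mid + half * math.tanh((raw - mid) / half)


def smooth_bound_inverse(value, low, high):
    mid, half = 0.5 * (low + high), 0.5 * (high - low)
    # Requires low < value < high; atanh is undefined at +/-1.
    return mid + half * math.atanh((value - mid) / half)


bounded = smooth_bound(0.8, 0.0, 1.0)
recovered = smooth_bound_inverse(bounded, 0.0, 1.0)
```

Here recovered agrees with the original raw value 0.8 up to floating-point error, which is what makes the inverse usable for placing an initial chain position at a chosen physics-space value.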

heterodyne.optimization.cmc.scaling.compute_scaling_factors(space, nlsq_values=None, nlsq_uncertainties=None, width_factor=2.0)[source]

Build ParameterScaling for each varying parameter.

Uses NLSQ values as centers and NLSQ uncertainties × width_factor as scale. Falls back to bounds midpoint and range/6 when NLSQ results are unavailable.

Parameters:
  • space (ParameterSpace) – Parameter space with bounds and varying flags.

  • nlsq_values (dict[str, float] | None) – NLSQ best-fit values by name.

  • nlsq_uncertainties (dict[str, float] | None) – NLSQ uncertainties by name.

  • width_factor (float) – Multiplier on NLSQ uncertainty for prior width.

Return type:

dict[str, ParameterScaling]

Returns:

Dict mapping parameter name to ParameterScaling.
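
The center and scale selection, including the fallback when no NLSQ result is available, can be sketched for one parameter as below. The function name scaling_for and the plain-dict return value are illustrative; treating a missing or non-positive uncertainty as "unavailable" is an assumption.

```python
def scaling_for(name, low, high, nlsq_value=None, nlsq_sigma=None,
                width_factor=2.0):
    """Pick (center, scale) for one parameter; returns a plain dict."""
    have_nlsq = (nlsq_value is not None
                 and nlsq_sigma is not None
                 and nlsq_sigma > 0)  # assumption: sigma <= 0 counts as missing
    if have_nlsq:
        center = nlsq_value
        scale = nlsq_sigma * width_factor
    else:
        # Fallback from the description: bounds midpoint and range / 6.
        center = 0.5 * (low + high)
        scale = (high - low) / 6.0
    return {"name": name, "center": center, "scale": scale,
            "low": low, "high": high}
```

With range / 6 as the scale, the interval center ± 3 × scale spans the bounds exactly, so a standard-normal z covers essentially the whole allowed range.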

heterodyne.optimization.cmc.scaling.log_scaling_factors(scalings)[source]

Log all scaling factors for debugging.

Emits a header at INFO level, then per-parameter details at DEBUG.

Parameters:

scalings (dict[str, ParameterScaling]) – Mapping of parameter name to its scaling specification.

Return type:

None

heterodyne.optimization.cmc.scaling.transform_initial_values_to_z(initial_values, scalings)[source]

Transform initial values from physics space to z-space.

Only transforms parameters present in both initial_values and scalings.

Parameters:
  • initial_values (dict[str, float]) – Physics-space values keyed by parameter name.

  • scalings (dict[str, ParameterScaling]) – Scaling specifications keyed by parameter name.

Return type:

dict[str, float]

Returns:

Dict with keys {name}_z mapped to normalized z-space values.
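
The key convention is the main thing to get right when initialising chains. The sketch below represents each scaling as a plain dict with center and scale entries, which is a simplification of ParameterScaling; the name to_z is hypothetical.

```python
def to_z(initial_values, scalings):
    """Map physics-space values to {name}_z keys using z = (v - center) / scale."""
    z_values = {}
    for name, value in initial_values.items():
        if name not in scalings:
            continue  # parameters without a scaling entry are left out
        spec = scalings[name]
        z_values[f"{name}_z"] = (value - spec["center"]) / spec["scale"]
    return z_values
```

A value equal to the scaling center maps to z = 0, so a warm start at the NLSQ optimum places every chain at the origin of z-space.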

heterodyne.optimization.cmc.scaling.transform_samples_from_z(samples, scalings)[source]

Transform MCMC samples from z-space back to physics space.

Input keys must end with "_z"; the suffix is stripped to recover the original parameter name. Only parameters with a matching entry in scalings are transformed.

Parameters:
  • samples (dict[str, Array]) – Z-space sample arrays keyed by {name}_z.

  • scalings (dict[str, ParameterScaling]) – Scaling specifications keyed by parameter name.

Return type:

dict[str, Array]

Returns:

Dict with original parameter names mapped to physics-space arrays.

Data Preparation

Data preparation utilities for CMC analysis.

Handles validation, JAX conversion, and sharding of correlation data for large-dataset MCMC workflows.

class heterodyne.optimization.cmc.data_prep.ShardingStrategy[source]

Bases: Enum

Strategy for partitioning prepared data into shards.

RANDOM

Randomly assign data points to shards with a fixed seed.

CONTIGUOUS

Split along the time axis into contiguous blocks.

STRATIFIED

Stratified by time range so each shard covers all epochs.

ANGLE_BALANCED

Each shard receives proportional representation from every phi angle, preventing heterogeneous sub-posteriors.

RANDOM = 'random'
CONTIGUOUS = 'contiguous'
STRATIFIED = 'stratified'
ANGLE_BALANCED = 'angle_balanced'
class heterodyne.optimization.cmc.data_prep.PreparedData[source]

Bases: object

Validated and structured data container for CMC/NUTS sampling.

All arrays are kept as NumPy for compatibility with both JAX and SciPy backends. The caller converts to JAX inside the sampler.

c2_data

Flattened observed correlation values, shape (n_total,).

weights

Per-element likelihood weights, shape (n_total,) or None when uniform weighting is used.

time_array

Unique time values used to build the time grid, shape (n_times,).

phi_angles

Per-element phi angles (radians or degrees), shape (n_total,).

q

Wavevector magnitude in Å⁻¹.

dt

Frame time step in seconds.

metadata

Arbitrary key/value pairs (configuration, provenance, …).

n_angles

Number of unique phi angles.

n_times

Length of time_array.

c2_data: ndarray
weights: ndarray | None
time_array: ndarray
phi_angles: ndarray
q: float
dt: float
metadata: dict[str, Any]
n_angles: int = 0
n_times: int = 0
__init__(c2_data, weights, time_array, phi_angles, q, dt, metadata=<factory>, n_angles=0, n_times=0)
class heterodyne.optimization.cmc.data_prep.PooledCMCData[source]

Bases: object

Pooled multi-phi data container for joint CMC (homodyne parity).

Mirrors homodyne.optimization.cmc.data_prep.PreparedData so heterodyne can run ONE NUTS pass conditioned on all phi angles with shared physics parameters. The pooled layout is (n_total,)-flat over angles + (t1, t2) grid; phi_indices maps each pooled point to its angle in phi_unique.

data

Pooled C2 values, shape (n_total,).

t1, t2

Pooled time coordinates, shape (n_total,).

phi

Pooled phi angles per point, shape (n_total,).

phi_unique

Sorted unique phi angles, shape (n_phi,).

phi_indices

Per-point index into phi_unique, shape (n_total,).

n_total

Number of pooled data points (after diagonal filtering).

n_phi

Cardinality of phi_unique.

noise_scale

MAD-based noise estimate, used to centre the sampled sigma prior.

data: ndarray
t1: ndarray
t2: ndarray
phi: ndarray
phi_unique: ndarray
phi_indices: ndarray
n_total: int
n_phi: int
noise_scale: float
__init__(data, t1, t2, phi, phi_unique, phi_indices, n_total, n_phi, noise_scale)
heterodyne.optimization.cmc.data_prep.validate_pooled_data(data, t1, t2, phi)[source]

Validate pooled arrays have matching length, finite values, sane shapes.

Return type:

None

heterodyne.optimization.cmc.data_prep.extract_phi_info(phi)[source]

Return (phi_unique, phi_indices) with tolerance-aware matching.

Mirrors homodyne.optimization.cmc.data_prep.extract_phi_info exactly. For n_phi <= 256 uses argmin(|phi - phi_unique|, axis=1) so float rounding doesn’t misassign points to neighbour angles; for larger phi counts falls back to searchsorted with a left-neighbour check.

Return type:

tuple[ndarray, ndarray]

heterodyne.optimization.cmc.data_prep.prepare_mcmc_data(data, t1, t2, phi, filter_diagonal=True)[source]

Validate + filter pooled XPCS data for joint multi-phi CMC.

Mirrors homodyne.optimization.cmc.data_prep.prepare_mcmc_data. With filter_diagonal=True (default), removes t1 == t2 rows using an epsilon-based comparison sized to the smallest positive dt in the arrays. The diagonal is loaded and plotted but excluded from likelihood fitting — same boundary contract as the t=0 row/col.

Parameters:
  • data, t1, t2, phi (ndarray) – Pooled (n_total,) arrays. Each index is one (C2 value, t1, t2, phi) observation.

  • filter_diagonal (bool) – When True, drop entries where |t1 - t2| <= ε.

Returns:

Filtered + indexed container ready for the NumPyro model.

Return type:

PooledCMCData
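
Diagonal filtering on pooled arrays can be sketched with plain lists. The exact epsilon used by the library is not given beyond being sized to the smallest positive time step, so taking half of the smallest positive |t1 - t2| is an assumption, as is the function name filter_diagonal.

```python
def filter_diagonal(data, t1, t2, phi):
    """Drop pooled points with |t1 - t2| <= eps; returns four filtered lists."""
    gaps = [abs(a - b) for a, b in zip(t1, t2)]
    positive = [g for g in gaps if g > 0]
    # Assumption: eps is half of the smallest positive lag in the data.
    eps = 0.5 * min(positive) if positive else 0.0
    keep = [i for i, g in enumerate(gaps) if g > eps]

    def pick(values):
        return [values[i] for i in keep]

    return pick(data), pick(t1), pick(t2), pick(phi)
```

All four arrays are filtered with the same index list so that each surviving entry still describes one consistent (C2, t1, t2, phi) observation.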

heterodyne.optimization.cmc.data_prep.shard_pooled_random(prepared, num_shards=None, max_points_per_shard=None, max_shards=100, seed=42)[source]

Shard pooled data into ~equal random subsets (homodyne parity).

Used when there is a single phi angle, or as the fallback for multi-angle data when angle-balanced sharding is not requested. Every data point lands in exactly one shard — no subsampling, no point loss.

Mirrors homodyne.optimization.cmc.data_prep.shard_data_random.

Parameters:
  • prepared (PooledCMCData) – Pooled multi-phi data container.

  • num_shards (int | None) – Explicit shard count. When None, derived from max_points_per_shard (ceil division), else 1.

  • max_points_per_shard (int | None) – Target points per shard used to derive num_shards when it is not given.

  • max_shards (int) – Hard cap on shard count; when exceeded the shard size grows so all data still fits in max_shards shards.

  • seed (int) – Seed for the index shuffle (reproducible assignment).

Return type:

list[PooledCMCData]

Returns:

List of PooledCMCData shards covering all points.
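
The shard-count derivation and the shuffle-then-split assignment can be sketched on bare indices. The name shard_indices_random is hypothetical, the strided split is one of several ways to get near-equal sizes, and capping the count at n_total is an added assumption.

```python
import math
import random


def shard_indices_random(n_total, num_shards=None, max_points_per_shard=None,
                         max_shards=100, seed=42):
    """Return a list of index lists; every index appears in exactly one."""
    if num_shards is None:
        if max_points_per_shard:
            num_shards = math.ceil(n_total / max_points_per_shard)
        else:
            num_shards = 1
    # Hard cap: fewer, larger shards rather than dropping points.
    num_shards = max(1, min(num_shards, max_shards, n_total))
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)  # seeded, so assignment is reproducible
    return [sorted(indices[k::num_shards]) for k in range(num_shards)]
```

Since the split partitions a permutation of all indices, no point is lost or duplicated, and shard sizes differ by at most one.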

heterodyne.optimization.cmc.data_prep.shard_pooled_angle_balanced(prepared, num_shards=None, max_points_per_shard=None, max_shards=500, min_angle_coverage=0.8, seed=42)[source]

Shard pooled data with proportional per-angle coverage (homodyne parity).

Preferred strategy for multi-angle datasets (n_phi > 1). Each shard samples proportionally from every phi angle so sub-posteriors stay homogeneous — pure random sharding can leave shards with uneven angle coverage, producing high cross-shard parameter variance that Consensus MC then combines incorrectly.

Mirrors homodyne.optimization.cmc.data_prep.shard_data_angle_balanced.

Parameters:
  • prepared (PooledCMCData) – Pooled multi-phi data container.

  • num_shards (int | None) – Explicit shard count. When None, derived from max_points_per_shard (ceil division), else max(1, n_phi).

  • max_points_per_shard (int | None) – Target points per shard used to derive num_shards when it is not given.

  • max_shards (int) – Hard cap on shard count.

  • min_angle_coverage (float) – Fraction of angles each shard should contain; shards below this are logged as a diagnostic (not an error).

  • seed (int) – Seed for per-angle shuffles (reproducible assignment).

Return type:

list[PooledCMCData]

Returns:

List of PooledCMCData shards with balanced angle coverage. Falls back to shard_pooled_random() when n_phi == 1.
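
Angle-balanced assignment can be illustrated by shuffling within each angle and dealing the points out across shards. This sketch works on phi_indices alone and uses a hypothetical name, shard_indices_angle_balanced; the coverage diagnostic and the max_shards cap are omitted.

```python
import random
from collections import defaultdict


def shard_indices_angle_balanced(phi_indices, num_shards, seed=42):
    """Give every shard a proportional share of the points from each angle."""
    rng = random.Random(seed)
    by_angle = defaultdict(list)
    for point, angle in enumerate(phi_indices):
        by_angle[angle].append(point)
    shards = [[] for _ in range(num_shards)]
    for angle in sorted(by_angle):
        points = by_angle[angle]
        rng.shuffle(points)  # shuffle within the angle only
        for k in range(num_shards):
            shards[k].extend(points[k::num_shards])
    return [sorted(s) for s in shards]
```

With 18 points spread evenly over three angles and three shards, each shard ends up with two points per angle, which is the homogeneity the strategy is meant to provide.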

heterodyne.optimization.cmc.data_prep.prepare_cmc_data(c2_data, sigma=None, weights=None)[source]

Validate and convert correlation data to JAX arrays.

Performs shape, dtype, NaN, and monotonicity checks on the input correlation matrix before transferring to JAX device memory.

Parameters:
  • c2_data (ndarray | Array) – Observed two-time correlation matrix. Must be 2-D and square (or 1-D for single-time slices).

  • sigma (ndarray | float | None) – Measurement uncertainty. Scalar broadcasts to all elements; array must match c2_data shape. If None, returns None and the caller is responsible for estimation.

  • weights (ndarray | Array | None) – Optional per-element weights for likelihood weighting. Must match c2_data shape if provided.

Return type:

tuple[Array, Array | float, Array | None]

Returns:

Tuple of (c2_jax, sigma_jax, weights_jax) ready for the NumPyro model. sigma_jax is the scalar or array sigma (or None passthrough when input is None). weights_jax is None when no weights are provided.

Raises:

ValueError – If data contains NaN, has mismatched shapes, or violates expected structure.

heterodyne.optimization.cmc.data_prep.create_shard_grid(n_times, n_shards)[source]

Create time-index partitions for sharding a correlation matrix.

Divides the time axis into approximately equal chunks so that each shard can be processed independently (e.g., for consensus Monte Carlo on very large two-time matrices).

Parameters:
  • n_times (int) – Number of time points along one axis of the correlation matrix.

  • n_shards (int) – Number of shards to create. Must be >= 1 and <= n_times.

Return type:

list[tuple[int, int]]

Returns:

List of (start, stop) index pairs (half-open intervals) that partition range(n_times) into n_shards contiguous chunks.

Raises:

ValueError – If n_shards < 1 or n_shards > n_times.
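
A half-open partition of the time axis can be built with divmod. The sketch below uses a hypothetical name, shard_grid, and gives the leftover points to the first shards, which is an assumption about how the remainder is distributed.

```python
def shard_grid(n_times, n_shards):
    """Partition range(n_times) into n_shards contiguous (start, stop) pairs."""
    if n_shards < 1 or n_shards > n_times:
        raise ValueError("n_shards must satisfy 1 <= n_shards <= n_times")
    base, extra = divmod(n_times, n_shards)
    grid, start = [], 0
    for k in range(n_shards):
        # Assumption: the first `extra` shards each take one additional point.
        stop = start + base + (1 if k < extra else 0)
        grid.append((start, stop))
        start = stop
    return grid
```

Because the intervals are half-open, each stop equals the next start and the pairs can be passed directly to slicing, for example matrix[start:stop, start:stop] for a diagonal block.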

heterodyne.optimization.cmc.data_prep.shard_correlation_data(c2_data, shard_grid)[source]

Split a two-time correlation matrix into shards along both axes.

Each shard is a sub-block of the full correlation matrix defined by the row and column index ranges in shard_grid. Only diagonal blocks (same row and column shard) are returned, as off-diagonal blocks carry cross-shard correlations that are handled separately in the consensus step.

Parameters:
  • c2_data (ndarray | Array) – Full two-time correlation matrix of shape (N, N).

  • shard_grid (list[tuple[int, int]]) – List of (start, stop) index pairs from create_shard_grid().

Return type:

list[Array]

Returns:

List of JAX arrays, one per shard, each of shape (stop - start, stop - start).

Raises:

ValueError – If c2_data is not 2-D or if shard indices are out of bounds.

heterodyne.optimization.cmc.data_prep.merge_shard_results(shard_results)[source]

Combine per-shard posterior samples via simple concatenation.

For consensus Monte Carlo, this implements the naive pooling strategy where samples from each shard’s sub-posterior are concatenated. The caller may apply further weighting or density product corrections.

All shards must contain the same set of parameter names.

Parameters:

shard_results (list[dict[str, ndarray]]) – List of sample dictionaries, one per shard. Each dict maps parameter names to 1-D arrays of posterior draws.

Return type:

dict[str, ndarray]

Returns:

Merged dictionary with concatenated samples for each parameter.

Raises:

ValueError – If shard results are empty or have mismatched keys.
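
Naive pooling amounts to a key check followed by concatenation. The sketch below uses lists in place of arrays and a hypothetical name, merge_shards.

```python
def merge_shards(shard_results):
    """Concatenate per-shard draws for each parameter name."""
    if not shard_results:
        raise ValueError("shard_results is empty")
    keys = set(shard_results[0])
    for index, shard in enumerate(shard_results[1:], start=1):
        if set(shard) != keys:
            raise ValueError(f"shard {index} has mismatched parameter names")
    return {name: [draw for shard in shard_results for draw in shard[name]]
            for name in sorted(keys)}
```

Concatenation treats every shard's draws as samples from the same distribution, which is why the description calls it naive: each sub-posterior is conditioned on only part of the data, so the pooled spread is generally wider than the full-data posterior.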

heterodyne.optimization.cmc.data_prep.prepare_data(raw_data, config=None)[source]

Validate, normalise, and package raw XPCS data for CMC sampling.

This is the main entry point for converting a raw data dictionary (as produced by the XPCS loader) into a PreparedData instance suitable for NUTS or CMC workflows.

Parameters:
  • raw_data (dict[str, Any]) –

    Dictionary with at least the following keys:

    • "c2_data" – array-like, shape (n_angles, n_t, n_t) or (n_t, n_t).

    • "phi_angles" – 1-D array of azimuthal angles (degrees or radians), length n_angles.

    • "time_array" – 1-D monotonically increasing time axis.

    • "q" – scalar wavevector magnitude (Å⁻¹).

    • "dt" – scalar frame time step (seconds).

    Optional keys:

    • "weights" – array matching c2_data, per-element likelihood weights.

    • "mask" – boolean array matching c2_data; True where data should be excluded.

  • config (dict[str, Any] | None) –

    Optional configuration dictionary. Recognised keys:

    • "normalize_weights" (bool, default True) – rescale weights so their mean equals 1.

    • "require_positive_diagonal" (bool, default True) – raise if any diagonal element <= 0.

Return type:

PreparedData

Returns:

PreparedData ready for create_shards().

Raises:
  • ValueError – If required keys are missing, arrays have unexpected shapes, or data contains NaN / non-finite values.

  • KeyError – If a required key is absent from raw_data.

heterodyne.optimization.cmc.data_prep.create_shards(prepared_data, n_shards, strategy=ShardingStrategy.ANGLE_BALANCED, *, seed=42)[source]

Split a PreparedData instance into n_shards sub-datasets.

Each shard is itself a PreparedData with the same q, dt, and time_array as the parent but containing only a subset of the pooled data points.

Parameters:
  • prepared_data (PreparedData) – Source data returned by prepare_data().

  • n_shards (int) – Number of shards to create. Must be >= 1.

  • strategy (ShardingStrategy) – Splitting strategy (see ShardingStrategy).

  • seed (int) – Random seed used by stochastic strategies (RANDOM, ANGLE_BALANCED).

Return type:

list[PreparedData]

Returns:

List of n_shards PreparedData instances.

Raises:

ValueError – If n_shards < 1 or strategy is unsupported.

heterodyne.optimization.cmc.data_prep.validate_shard_data(shard)[source]

Validate a single shard for common data quality issues.

Checks performed:

  • No NaN or non-finite values in c2_data.

  • Shape consistency between c2_data, phi_angles, and weights (when present).

  • At least one data point.

  • Positive values in the subset of elements corresponding to the diagonal of the original two-time matrix.

Parameters:

shard (PreparedData) – PreparedData shard to validate.

Raises:

ValueError – On any detected integrity issue.

Return type:

None

heterodyne.optimization.cmc.data_prep.estimate_shard_memory(shard)[source]

Estimate the device memory footprint of a shard in bytes.

Counts all NumPy arrays stored in the shard, using their actual nbytes attribute. This is a lower bound because JAX may add internal buffers during JIT compilation, but it is accurate enough for pre-flight capacity checks.

Parameters:

shard (PreparedData) – PreparedData shard.

Return type:

int

Returns:

Estimated memory in bytes.

CMC I/O

Full shard-level and aggregate I/O pipeline for CMC results. Supports NPZ sample arrays, ArviZ InferenceData, JSON parameter/diagnostics files, and fitted-data arrays.

Shard I/O for CMC (Consensus Monte Carlo) results.

Provides functions to persist and reload posterior samples as .npz archives and ArviZ InferenceData objects as NetCDF files.

heterodyne.optimization.cmc.io.save_shard_results(results, output_dir, shard_id)[source]

Save posterior samples for a single shard as a .npz archive.

The file is written to <output_dir>/shard_<shard_id>.npz.

Parameters:
  • results (dict[str, ndarray]) – Mapping of parameter name to sample array.

  • output_dir (str | Path) – Directory in which to save the archive.

  • shard_id (int) – Integer shard identifier.

Return type:

Path

Returns:

Path to the saved file.

heterodyne.optimization.cmc.io.load_shard_results(output_dir, shard_id)[source]

Load posterior samples for a single shard.

Parameters:
  • output_dir (str | Path) – Directory containing shard archives.

  • shard_id (int) – Integer shard identifier.

Return type:

dict[str, ndarray]

Returns:

Mapping of parameter name to sample array.

Raises:

FileNotFoundError – If the shard file does not exist.

heterodyne.optimization.cmc.io.list_shards(output_dir)[source]

Discover saved shard IDs in output_dir.

Scans for files matching shard_<N>.npz and returns a sorted list of the integer shard IDs.

Parameters:

output_dir (str | Path) – Directory to scan.

Return type:

list[int]

Returns:

Sorted list of shard IDs found.
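
Shard discovery by file name can be sketched with pathlib and a regular expression. The function name list_shard_ids and the exact pattern are assumptions based on the shard_<N>.npz naming described above; the demo creates empty placeholder files in a temporary directory.

```python
import re
import tempfile
from pathlib import Path

_SHARD_RE = re.compile(r"^shard_(\d+)\.npz$")


def list_shard_ids(output_dir):
    """Return the integer IDs of files named shard_<N>.npz, sorted numerically."""
    ids = []
    for path in Path(output_dir).iterdir():
        match = _SHARD_RE.match(path.name)
        if match and path.is_file():
            ids.append(int(match.group(1)))
    return sorted(ids)


with tempfile.TemporaryDirectory() as tmp:
    for name in ("shard_10.npz", "shard_2.npz", "shard_x.npz", "notes.txt"):
        (Path(tmp) / name).touch()
    found = list_shard_ids(tmp)
```

Converting the captured digits to int before sorting matters: a plain string sort would place shard_10 ahead of shard_2.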

heterodyne.optimization.cmc.io.save_inference_data(idata, path)[source]

Save an ArviZ InferenceData object as a NetCDF file.

Parameters:
  • idata (object) – ArviZ InferenceData instance.

  • path (str | Path) – Destination file path (should end in .nc).

Return type:

Path

Returns:

Path to the saved file.

Raises:

ImportError – If arviz is not installed.

heterodyne.optimization.cmc.io.load_inference_data(path)[source]

Load an ArviZ InferenceData object from a NetCDF file.

Parameters:

path (str | Path) – Path to a NetCDF file previously written by save_inference_data().

Return type:

object

Returns:

ArviZ InferenceData object.

Raises:

heterodyne.optimization.cmc.io.save_samples_npz(result, output_path)[source]

Save posterior samples as a compressed NPZ archive (ArviZ-compatible).

Shape stored: posterior_samples is (n_chains, n_samples, n_params).

Return type:

None

heterodyne.optimization.cmc.io.load_samples_npz(input_path)[source]

Load samples NPZ and return a plain Python dict.

Raises:

Return type:

dict[str, Any]

heterodyne.optimization.cmc.io.samples_to_arviz(samples_data)[source]

Convert loaded samples dict to ArviZ InferenceData.

Parameters:

samples_data (dict[str, Any]) – Data returned by load_samples_npz().

Return type:

Any

heterodyne.optimization.cmc.io.save_fitted_data_npz(result, c2_exp, c2_fitted, c2_fitted_std, t1, t2, phi_angles, q, output_path)[source]

Save fitted C2 data in NLSQ-compatible NPZ format.

Return type:

None

heterodyne.optimization.cmc.io.save_parameters_json(result, output_path)[source]

Save posterior statistics per parameter to JSON.

NaN is written as null; infinities are written as the strings "Infinity" and "-Infinity".

Return type:

None
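One way to obtain the NaN and Inf mapping described above is to sanitize the statistics before calling json.dumps. The sketch below is an assumption about how such a step could look; the helper name sanitize and the example field names (mean, std, hdi) are hypothetical.

```python
import json
import math


def sanitize(value):
    """Recursively replace non-finite floats with JSON-safe values."""
    if isinstance(value, float):
        if math.isnan(value):
            return None  # serialised as JSON null
        if math.isinf(value):
            return "Infinity" if value > 0 else "-Infinity"
        return value
    if isinstance(value, dict):
        return {key: sanitize(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [sanitize(item) for item in value]
    return value


# Hypothetical per-parameter statistics containing non-finite entries.
stats = {"D0_sample": {"mean": 1.5, "std": float("nan"),
                       "hdi": [float("-inf"), float("inf")]}}

# allow_nan=False makes json.dumps fail loudly if anything slipped through.
text = json.dumps(sanitize(stats), allow_nan=False)
```

Passing allow_nan=False is a defensive choice: without it, json.dumps emits the non-standard tokens NaN and Infinity, which strict JSON parsers reject.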

heterodyne.optimization.cmc.io.save_diagnostics_json(result, output_path, warnings=None)[source]

Save convergence diagnostics to JSON.

Return type:

None

heterodyne.optimization.cmc.io.save_all_results(result, output_dir, c2_exp=None, c2_fitted=None, c2_fitted_std=None, t1=None, t2=None, phi_angles=None, q=None)[source]

Save all CMC result files and return a dict of {key: path}.

Return type:

dict[str, Path]

CMC Plotting

ArviZ-backed diagnostic plots for CMC posterior samples.

ArviZ-based MCMC diagnostic plots for CMC results.

All functions accept ArviZ InferenceData objects and return matplotlib figures. arviz is imported with a try/except guard so that the rest of the package does not hard-depend on it.

heterodyne.optimization.cmc.plotting.plot_trace_summary(idata, var_names=None, figsize=None)[source]

ArviZ trace plot with marginal posteriors.

Parameters:
  • idata (object) – ArviZ InferenceData object.

  • var_names (list[str] | None) – Subset of variable names to plot (None for all).

  • figsize (tuple[float, float] | None) – Optional figure size override.

Return type:

Figure

Returns:

Matplotlib Figure.

heterodyne.optimization.cmc.plotting.plot_pair_plot(idata, var_names=None, divergences=True)[source]

ArviZ pair plot with optional divergence markers.

Parameters:
  • idata (object) – ArviZ InferenceData object.

  • var_names (list[str] | None) – Subset of variable names.

  • divergences (bool) – Whether to overlay divergence markers.

Return type:

Figure

Returns:

Matplotlib Figure.

heterodyne.optimization.cmc.plotting.plot_posterior_predictive(idata, c2_data, times, ax=None)[source]

Overlay posterior-predictive draws on experimental data.

If idata contains a posterior_predictive group with a variable named "c2_pred", the 5th/95th percentile envelope is drawn. Otherwise a message is displayed.

Parameters:
  • idata (object) – ArviZ InferenceData object.

  • c2_data (ndarray) – 2-D experimental correlation matrix, shape (N, N).

  • times (ndarray) – 1-D time array.

  • ax (Axes | None) – Optional existing Axes.

Return type:

Axes

Returns:

The matplotlib Axes.

heterodyne.optimization.cmc.plotting.plot_diagnostics_summary(idata)[source]

Combined R-hat, ESS, and BFMI diagnostic panels.

Creates a three-panel figure:

  1. R-hat per parameter (bar chart).

  2. Bulk ESS per parameter (bar chart).

  3. BFMI per chain (bar chart).

Parameters:

idata (object) – ArviZ InferenceData object.

Return type:

Figure

Returns:

Matplotlib Figure.

heterodyne.optimization.cmc.plotting.plot_forest(idata, output_dir, var_names=None, figsize=DEFAULT_FIGSIZE, dpi=DEFAULT_DPI)[source]

Save an ArviZ forest plot to output_dir / forest_plot.png.

Shows posterior intervals (94% HDI by default) for each parameter. Homodyne-parity helper.

Return type:

Path

heterodyne.optimization.cmc.plotting.plot_energy(idata, output_dir, figsize=(10, 6), dpi=DEFAULT_DPI)[source]

Save an ArviZ energy plot to output_dir / energy_plot.png.

Falls back gracefully when sample_stats lacks an energy field by writing a placeholder image with an explanatory message. Homodyne-parity helper that handles NumPyro’s potential_energy naming.

Return type:

Path

heterodyne.optimization.cmc.plotting.plot_autocorr(idata, output_dir, var_names=None, figsize=DEFAULT_FIGSIZE, dpi=DEFAULT_DPI)[source]

Save an ArviZ autocorrelation plot. Homodyne-parity helper.

Return type:

Path

heterodyne.optimization.cmc.plotting.plot_rank(idata, output_dir, var_names=None, figsize=DEFAULT_FIGSIZE, dpi=DEFAULT_DPI)[source]

Save an ArviZ rank plot. Helps detect chain-mixing issues.

Return type:

Path

heterodyne.optimization.cmc.plotting.plot_ess(idata, output_dir, var_names=None, figsize=(10, 6), dpi=DEFAULT_DPI)[source]

Save an ArviZ ESS-evolution plot.

Return type:

Path

heterodyne.optimization.cmc.plotting.generate_diagnostic_plots(idata, output_dir, var_names=None, dpi=DEFAULT_DPI)[source]

Generate the full homodyne-parity diagnostic plot suite.

Writes forest_plot.png, energy_plot.png, autocorr_plot.png, rank_plot.png, and ess_plot.png to output_dir. Individual failures are isolated — one broken plot does not abort the others.

Parameters:
  • idata (object) – ArviZ InferenceData.

  • output_dir (Path) – Directory to write the PNGs into. Created if missing.

  • var_names (list[str] | None) – Optional explicit subset of parameter names to include. Defaults to physical (non-scaling) sites.

  • dpi (int) – Resolution of each PNG.

Return type:

dict[str, Path]

Returns:

Mapping plot_kind -> output_path for each plot that succeeded.
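The failure isolation described for generate_diagnostic_plots can be illustrated with a small loop that catches exceptions per plot. This is a sketch under assumptions: run_isolated and the fake_* plotters are hypothetical stand-ins, and the real function writes PNG files rather than returning strings.

```python
import logging

logger = logging.getLogger(__name__)


def run_isolated(plotters, output_dir):
    """Call each plotter in turn; keep successes and log failures."""
    written = {}
    for kind, plot_fn in plotters.items():
        try:
            written[kind] = plot_fn(output_dir)
        except Exception as exc:  # one broken plot must not abort the rest
            logger.warning("plot %r failed: %s", kind, exc)
    return written


# Hypothetical plotters standing in for the real ArviZ-backed helpers.
def fake_forest(output_dir):
    return f"{output_dir}/forest_plot.png"


def fake_energy(output_dir):
    raise KeyError("energy")  # e.g. sample_stats has no energy field


def fake_rank(output_dir):
    return f"{output_dir}/rank_plot.png"


written = run_isolated(
    {"forest": fake_forest, "energy": fake_energy, "rank": fake_rank}, "out"
)
```

Only the plots that succeeded appear in the returned mapping, which matches the documented return value.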

Backends

Execution backends for running NUTS chains across CPU cores, HPC cluster nodes, and multi-device pjit configurations. All backends share the MCMCBackend protocol and are selected via select_backend() based on the backend_name setting ("auto" by default).

Backend Selection

Abstract base and factory for MCMC execution backends.

Includes Consensus Monte Carlo utilities for combining posteriors from independent MCMC shards via inverse-variance (precision) weighting.

class heterodyne.optimization.cmc.backends.base.BackendCapabilities[source]

Bases: object

Static description of what an MCMC backend can do.

Used by the backend selection logic and resource estimation code to choose the best available backend at runtime without instantiating every candidate.

supports_sharding

True if the backend can distribute data shards across workers or devices.

supports_parallel_chains

True if chains can run concurrently (e.g. via pmap or a worker pool).

max_parallel_shards

Maximum number of shards the backend can handle simultaneously. 1 means strictly sequential.

supports_sharding: bool = False
supports_parallel_chains: bool = True
max_parallel_shards: int = 1
__init__(supports_sharding=False, supports_parallel_chains=True, max_parallel_shards=1)

class heterodyne.optimization.cmc.backends.base.MCMCBackend[source]

Bases: Protocol

Protocol for MCMC execution backends.

Each backend wraps NumPyro’s MCMC machinery with a CPU execution strategy (sequential single-device or parallel multi-device).

run(model, config, rng_key, init_params=None)[source]

Run MCMC sampling and return posterior samples.

Parameters:
  • model (Callable[..., Any]) – NumPyro model function (callable with no required args).

  • config (CMCConfig) – CMC configuration with sampling hyperparameters.

  • rng_key (Array) – JAX PRNG key for reproducibility.

  • init_params (dict[str, Array] | None) – Optional initial parameter values per chain. Keys are parameter names; values have shape (num_chains,).

Return type:

dict[str, Any]

Returns:

Dictionary mapping parameter names to sample arrays. Each array has shape (num_samples * num_chains,) for ungrouped samples, matching NumPyro’s default get_samples() behavior.

__init__(*args, **kwargs)

class heterodyne.optimization.cmc.backends.base.CMCBackend[source]

Bases: ABC

Abstract base class for CMC execution backends.

Concrete subclasses implement CPU MCMC execution strategies (sequential, multi-device parallel, worker-pool, etc.). Subclasses must override run, get_capabilities, validate_resources, estimate_memory, and cleanup.

abstractmethod run(model, config, rng_key, init_params=None)[source]

Run MCMC sampling and return posterior samples.

Parameters:
  • model (Callable[..., Any]) – NumPyro model function.

  • config (CMCConfig) – CMC configuration with sampling hyperparameters.

  • rng_key (Array) – JAX PRNG key for reproducibility.

  • init_params (dict[str, Array] | None) – Optional per-chain initial parameter values.

Return type:

dict[str, Any]

Returns:

Dictionary mapping parameter names to flat sample arrays.

abstractmethod get_capabilities()[source]

Return a static description of this backend’s capabilities.

Return type:

BackendCapabilities

Returns:

Frozen BackendCapabilities dataclass.

abstractmethod validate_resources()[source]

Check that required hardware and software resources are available.

Raises:

RuntimeError – If a required resource (device, library, memory) is unavailable.

Return type:

None

abstractmethod estimate_memory(n_data, n_params, n_chains)[source]

Estimate peak memory consumption for a sampling run.

The estimate is intentionally conservative (upper-bound) to help callers decide whether to proceed or reduce chain count / shard size.

Parameters:
  • n_data (int) – Number of data points per shard.

  • n_params (int) – Number of model parameters.

  • n_chains (int) – Number of MCMC chains to run.

Return type:

float

Returns:

Estimated peak memory in gigabytes.

abstractmethod cleanup()[source]

Release any resources held by this backend.

Called after sampling is complete. Implementations should be idempotent (safe to call more than once).

Return type:

None
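The five-method contract above can be illustrated with a self-contained toy. Everything here is a stand-in: Capabilities mirrors the documented BackendCapabilities fields, SketchBackend mirrors the abstract method names, and DummyBackend with its dict-based config is hypothetical and does no real sampling.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class Capabilities:
    """Stand-in for BackendCapabilities, with the documented defaults."""
    supports_sharding: bool = False
    supports_parallel_chains: bool = True
    max_parallel_shards: int = 1


class SketchBackend(ABC):
    """Stand-in base class exposing the documented abstract methods."""

    @abstractmethod
    def run(self, model, config, rng_key, init_params=None): ...

    @abstractmethod
    def get_capabilities(self): ...

    @abstractmethod
    def validate_resources(self): ...

    @abstractmethod
    def estimate_memory(self, n_data, n_params, n_chains): ...

    @abstractmethod
    def cleanup(self): ...


class DummyBackend(SketchBackend):
    def __init__(self):
        self._open = True

    def run(self, model, config, rng_key, init_params=None):
        # Toy "sampler": calls the model num_samples times.
        return {"theta": [model() for _ in range(config["num_samples"])]}

    def get_capabilities(self):
        return Capabilities(supports_parallel_chains=False)

    def validate_resources(self):
        return None  # nothing to check in this toy

    def estimate_memory(self, n_data, n_params, n_chains):
        return 8.0 * (n_data + n_params) / 1e9  # float64 bytes -> GB

    def cleanup(self):
        self._open = False  # idempotent: a second call changes nothing


backend = DummyBackend()
backend.validate_resources()
try:
    samples = backend.run(lambda: 0.0, {"num_samples": 3}, rng_key=None)
finally:
    backend.cleanup()
    backend.cleanup()  # safe to call more than once
```

Calling cleanup() in a finally block follows the documented expectation that backends release resources after sampling, whether or not run() raised.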

heterodyne.optimization.cmc.backends.base.select_backend(config)[source]

Select the appropriate MCMC backend.

Selection order (mirrors homodyne select_backend semantics):

  1. backend_name == "pbs" → PBSBackend (raises if qsub is not on PATH).

  2. backend_name == "multiprocessing" or legacy alias "jax" → MultiprocessingBackend.

  3. backend_name == "pjit" → PjitBackend.

  4. backend_name == "cpu" → CPUBackend.

  5. backend_name == "auto" (default): heuristic (MultiprocessingBackend if n_chains >= 3 and at least 2 physical workers, then PjitBackend when len(jax.devices()) > 1, else CPUBackend).

Parameters:

config (CMCConfig) – CMC configuration.

Return type:

MCMCBackend

Returns:

An instantiated backend ready for run().

Raises:

ValueError – If backend_name is not in the supported set.
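The selection order above can be written out as a pure function. This is only a sketch of the documented decision table: choose_backend_name is a hypothetical helper that returns class names as strings, and the worker and device counts are passed in explicitly rather than probed from the machine.

```python
# Explicit names map directly to a backend; "jax" is the legacy alias.
_EXPLICIT = {
    "pbs": "PBSBackend",
    "multiprocessing": "MultiprocessingBackend",
    "jax": "MultiprocessingBackend",
    "pjit": "PjitBackend",
    "cpu": "CPUBackend",
}


def choose_backend_name(backend_name, n_chains, n_workers, n_devices):
    """Return the name of the backend class the documented rules would pick."""
    if backend_name in _EXPLICIT:
        return _EXPLICIT[backend_name]
    if backend_name == "auto":
        if n_chains >= 3 and n_workers >= 2:
            return "MultiprocessingBackend"
        if n_devices > 1:
            return "PjitBackend"
        return "CPUBackend"
    raise ValueError(f"unsupported backend_name: {backend_name!r}")


auto_many = choose_backend_name("auto", n_chains=4, n_workers=8, n_devices=1)
auto_pjit = choose_backend_name("auto", n_chains=2, n_workers=8, n_devices=4)
auto_cpu = choose_backend_name("auto", n_chains=2, n_workers=1, n_devices=1)
legacy = choose_backend_name("jax", n_chains=1, n_workers=1, n_devices=1)
```

The sketch omits the qsub availability check that the "pbs" branch performs in the real function.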

class heterodyne.optimization.cmc.backends.base.ShardPosterior[source]

Bases: object

Posterior summary from a single MCMC shard.

mean

Parameter means, shape (n_params,).

covariance

Covariance matrix, shape (n_params, n_params).

n_samples

Number of effective posterior samples in this shard.

shard_id

Optional identifier for logging / diagnostics.

mean: ndarray
covariance: ndarray
n_samples: int = 0
shard_id: int = 0
__init__(mean, covariance, n_samples=0, shard_id=0)

heterodyne.optimization.cmc.backends.base.consensus_mc(shard_posteriors)[source]

Combine shard posteriors using Consensus Monte Carlo.

Each shard’s posterior is summarised by its mean and covariance. The combined posterior is the precision-weighted average:

Λ_combined = Σ_k Λ_k          (sum of precisions)
μ_combined = Λ_combined⁻¹ Σ_k Λ_k μ_k

This is exact when the sub-posteriors are Gaussian and the prior factorises across shards (the “embarrassingly parallel” regime).

Parameters:

shard_posteriors (list[ShardPosterior]) – List of ShardPosterior, one per shard. All must have the same dimensionality.

Return type:

tuple[ndarray, ndarray]

Returns:

Tuple of (combined_mean, combined_covariance) where combined_mean has shape (n_params,) and combined_covariance has shape (n_params, n_params).

Raises:

ValueError – If no shards are provided or shapes are inconsistent.
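For a single parameter the precision-weighted formulas reduce to scalars, which makes them easy to check by hand. The sketch below assumes that one-dimensional case (precision = 1 / variance); consensus_scalar is a hypothetical helper, whereas the documented function works on full covariance matrices.

```python
def consensus_scalar(means, variances):
    """Combine per-shard (mean, variance) pairs by precision weighting."""
    if not means or len(means) != len(variances):
        raise ValueError("need at least one shard and matching lengths")
    precisions = [1.0 / var for var in variances]           # Λ_k
    total_precision = sum(precisions)                        # Λ_combined
    weighted = sum(p * m for p, m in zip(precisions, means))
    combined_mean = weighted / total_precision               # μ_combined
    combined_variance = 1.0 / total_precision
    return combined_mean, combined_variance


# Two equally precise shards: the mean is the midpoint and the variance halves.
equal_mean, equal_var = consensus_scalar([1.0, 3.0], [1.0, 1.0])

# A tighter shard pulls the combined mean towards itself.
skew_mean, skew_var = consensus_scalar([1.0, 3.0], [0.25, 1.0])
```

In the second call the precisions are 4 and 1, so the combined mean is (4·1 + 1·3) / 5 = 1.4 and the combined variance is 1 / 5 = 0.2.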

heterodyne.optimization.cmc.backends.base.robust_consensus_mc(shard_posteriors, *, outlier_sigma=3.0)[source]

Combine shard posteriors with outlier-resistant weighting.

Like consensus_mc() but first identifies and downweights outlier shards whose means deviate from the cross-shard median by more than outlier_sigma MAD-scaled deviations.

Outlier detection uses the median absolute deviation (MAD) of per-shard means for each parameter. Shards flagged as outliers on any parameter have their precision scaled by 1 / n_shards (i.e. they contribute but don’t dominate).

Parameters:
  • shard_posteriors (list[ShardPosterior]) – List of ShardPosterior.

  • outlier_sigma (float) – Number of MAD-scaled deviations beyond which a shard is considered an outlier. Default 3.0.

Return type:

tuple[ndarray, ndarray]

Returns:

Tuple of (combined_mean, combined_covariance).

Raises:

ValueError – If fewer than 2 shards are provided (need ≥ 2 for robust statistics) or shapes are inconsistent.
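The outlier rule can be sketched as follows. Two details are assumptions rather than documented behaviour: the constant 1.4826 that rescales the MAD to a Gaussian standard deviation, and the choice to skip a parameter whose MAD is zero. The helper names are hypothetical.

```python
from statistics import median

MAD_TO_SIGMA = 1.4826  # assumed Gaussian consistency constant


def flag_outlier_shards(shard_means, outlier_sigma=3.0):
    """Flag shards whose mean is a MAD outlier on any parameter."""
    n_shards = len(shard_means)
    if n_shards < 2:
        raise ValueError("need at least 2 shards for robust statistics")
    flagged = [False] * n_shards
    for j in range(len(shard_means[0])):
        column = [means[j] for means in shard_means]
        centre = median(column)
        scale = MAD_TO_SIGMA * median(abs(x - centre) for x in column)
        if scale == 0.0:
            continue  # assumed: no spread means nothing can be flagged
        for k, x in enumerate(column):
            if abs(x - centre) > outlier_sigma * scale:
                flagged[k] = True
    return flagged


def precision_weights(flagged):
    """Outliers keep a 1 / n_shards share of their precision."""
    n_shards = len(flagged)
    return [1.0 / n_shards if is_outlier else 1.0 for is_outlier in flagged]


# Five shards, two parameters; the last shard is far off on parameter 0.
means = [[1.0, 10.0], [1.1, 10.2], [0.9, 9.8], [1.05, 10.1], [5.0, 10.0]]
flags = flag_outlier_shards(means)
weights = precision_weights(flags)
```

The resulting weights would multiply each shard's precision before the usual precision-weighted combination, so a flagged shard still contributes but cannot dominate.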

heterodyne.optimization.cmc.backends.base.combine_shard_samples(shard_samples, *, method='consensus_mc', chunk_size=500, seed=42)[source]

Combine raw posterior samples from multiple CMC shards.

Homodyne CMC parity wrapper. Each shard contributes a dictionary of per-parameter posterior draws (typically shape (n_chains, n_samples)); the function returns a single combined dictionary with the same per-parameter shape.

Pathways:
  • Single shard: returned unchanged.

  • K ≤ chunk_size, where K is the number of shards: single-pass precision-weighted combination on per-shard (mean, variance) summaries, with non-finite filtering and degenerate-shard exclusion (variance < 1e-6 × median variance).

  • K > chunk_size: moment accumulation across chunks, with a single Gaussian draw at the end. This avoids the recursive precision-inflation bug that arose from re-combining synthetic intermediate samples.

Parameters:
  • shard_samples (list[dict[str, ndarray]]) – List of per-shard sample dicts. All must share the same parameter-name set and per-parameter shape (C, S).

  • method (str) – Combination method. "consensus_mc" / "robust_consensus_mc" / "weighted_gaussian" / "auto" all map to precision-weighted Gaussian recombination at this granularity. "simple_average" averages without precision weighting.

  • chunk_size (int) – Threshold for hierarchical mode. Default 500 keeps per-step peak memory bounded.

  • seed (int) – PRNG seed for the synthetic Gaussian draw at the end.

Return type:

dict[str, ndarray]

Returns:

Combined samples dict with the same keys/shapes as the per-shard inputs.

Raises:

ValueError – If shard_samples is empty or shards disagree on parameter names.
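The filtering step in the single-pass pathway can be sketched for one parameter. The order of operations here (drop non-finite summaries first, then take the median variance of what remains) and the helper name usable_shards are assumptions; only the 1e-6 × median variance floor comes from the description above.

```python
import math
from statistics import median


def usable_shards(summaries, rel_floor=1e-6):
    """Return indices of shards that pass both filters.

    summaries is a list of per-shard (mean, variance) pairs.
    """
    finite = [
        (idx, var)
        for idx, (mean, var) in enumerate(summaries)
        if math.isfinite(mean) and math.isfinite(var) and var > 0.0
    ]
    if not finite:
        return []
    median_var = median(var for _, var in finite)
    # A near-zero variance would give one shard an enormous precision weight.
    return [idx for idx, var in finite if var >= rel_floor * median_var]


summaries = [
    (1.0, 0.04),
    (1.1, 0.05),
    (float("nan"), 0.04),   # non-finite mean: filtered out
    (0.9, 1e-12),           # collapsed chain: degenerate variance
    (1.05, 0.045),
]
kept = usable_shards(summaries)
```

Excluding collapsed shards matters because precision weighting uses 1 / variance, so a variance near zero would otherwise swamp every other shard.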

heterodyne.optimization.cmc.backends.base.combine_shard_samples_bimodal(shard_samples, *, cluster_param=None, method='consensus_mc', seed=42)[source]

Mode-aware combination — cluster shards by posterior mode then combine within cluster.

Homodyne CMC parity helper for multimodal posteriors. Uses a simple 2-means clustering on per-shard posterior means of cluster_param (default: the first parameter) to partition shards into two modes, then runs combine_shard_samples() within each cluster.

Parameters:
  • shard_samples (list[dict[str, ndarray]]) – Per-shard sample dictionaries.

  • cluster_param (str | None) – Parameter name to use for clustering. None picks the first parameter alphabetically.

  • method (str) – Combination method passed to combine_shard_samples().

  • seed (int) – PRNG seed for clustering tiebreaker and final draws.

Return type:

dict[str, dict[str, ndarray]]

Returns:

Mapping cluster_id -> combined_samples_dict. Cluster ids are "mode_low" and "mode_high" ordered by mean value of cluster_param. If clustering fails (e.g. fewer than 2 shards per mode), all shards are combined into a single "mode_low" bucket.
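A one-dimensional 2-means pass over the per-shard means might look like the sketch below. The initialisation at the minimum and maximum, the iteration cap, and the helper name two_means_1d are assumptions; the real helper may also apply a seeded tiebreaker, which is not modelled here.

```python
def two_means_1d(values, n_iter=50):
    """Label each value 0 (low cluster) or 1 (high cluster)."""
    low_centre, high_centre = min(values), max(values)
    if low_centre == high_centre:
        return [0] * len(values)  # no spread: everything in one cluster
    labels = []
    for _ in range(n_iter):
        labels = [
            0 if abs(v - low_centre) <= abs(v - high_centre) else 1
            for v in values
        ]
        low = [v for v, lab in zip(values, labels) if lab == 0]
        high = [v for v, lab in zip(values, labels) if lab == 1]
        new_low, new_high = sum(low) / len(low), sum(high) / len(high)
        if (new_low, new_high) == (low_centre, high_centre):
            break  # centres stopped moving
        low_centre, high_centre = new_low, new_high
    return labels


# Per-shard posterior means of the clustering parameter (illustrative values).
shard_means = [0.9, 1.1, 1.0, 4.9, 5.2, 5.0]
labels = two_means_1d(shard_means)
buckets = {
    "mode_low": [i for i, lab in enumerate(labels) if lab == 0],
    "mode_high": [i for i, lab in enumerate(labels) if lab == 1],
}
```

Each bucket of shard indices would then be passed to the ordinary combination routine separately, giving one combined posterior per mode.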

CPU Backend (Sequential / Vectorized)

CPU-optimized MCMC execution backend.

Runs NUTS chains sequentially (one at a time) to avoid memory pressure on CPU-only machines where all chains share the same memory pool.

class heterodyne.optimization.cmc.backends.cpu_backend.CPUBackend[source]

Bases: CMCBackend

CPU-optimized MCMC backend using sequential chain execution.

Runs each MCMC chain one at a time via NumPyro’s chain_method="sequential" to keep peak memory usage proportional to a single chain. This is the recommended backend for single-device CPU machines.

run(model, config, rng_key, init_params=None)[source]

Run NUTS sampling with sequential chain execution.

Parameters:
  • model (Callable[..., Any]) – NumPyro model function.

  • config (CMCConfig) – CMC configuration.

  • rng_key (Array) – JAX PRNG key.

  • init_params (dict[str, Array] | None) – Optional per-chain initial values.

Return type:

dict[str, Any]

Returns:

Dictionary of posterior samples from all chains.

Raises:

RuntimeError – If MCMC sampling fails.

get_capabilities()[source]

Return CPU backend capabilities.

The CPU backend runs chains sequentially (one at a time) and does not support cross-device sharding.

Return type:

BackendCapabilities

Returns:

BackendCapabilities reflecting sequential CPU execution.

validate_resources()[source]

Verify that CPU resources are available for sampling.

Checks that at least one JAX CPU device is accessible.

Raises:

RuntimeError – If no CPU device is found via jax.devices.

Return type:

None

estimate_memory(n_data, n_params, n_chains)[source]

Estimate peak CPU memory for a single sequential chain.

Because chains run sequentially only one chain’s state is live at any moment, so n_chains does not multiply peak usage.

The formula accounts for:

  • Flat parameter storage per draw: n_params float64 scalars.

  • Gradient / momentum buffers: same size as parameters.

  • Sample storage for the completed chain: num_samples * n_params.

  • Data residual buffer: n_data float64 scalars.

  • A conservative overhead multiplier (_CPU_MEMORY_OVERHEAD_FACTOR) for JAX tracing buffers and NumPyro auxiliary state.

Parameters:
  • n_data (int) – Number of data points per shard.

  • n_params (int) – Number of model parameters.

  • n_chains (int) – Number of chains (not used for sequential backend; included for API uniformity).

Return type:

float

Returns:

Estimated peak memory in gigabytes.
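The terms listed for estimate_memory can be added up as in the sketch below. Several choices are assumptions: the overhead factor 4.0 is a placeholder for _CPU_MEMORY_OVERHEAD_FACTOR (whose value is not given here), the gradient and momentum buffers are counted as one parameter-sized block, and a gigabyte is taken as 1e9 bytes.

```python
BYTES_PER_FLOAT64 = 8
ASSUMED_OVERHEAD_FACTOR = 4.0  # placeholder, not the library's constant


def estimate_cpu_memory_gb(n_data, n_params, num_samples,
                           overhead=ASSUMED_OVERHEAD_FACTOR):
    """Sum the documented float64 buffers for one sequential chain."""
    parameter_state = n_params
    gradient_momentum = n_params           # "same size as parameters"
    sample_storage = num_samples * n_params
    residual_buffer = n_data
    n_floats = (parameter_state + gradient_momentum
                + sample_storage + residual_buffer)
    return overhead * n_floats * BYTES_PER_FLOAT64 / 1e9


# One million data points, 14 parameters, 1000 draws.
gb = estimate_cpu_memory_gb(n_data=1_000_000, n_params=14, num_samples=1000)
```

With these inputs the residual buffer dominates: 1,014,028 floats in total, of which 1,000,000 are data points, so the estimate scales almost linearly with n_data.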

cleanup()[source]

Release CPU backend resources.

The CPU backend holds no persistent state beyond what JAX and NumPyro manage internally, so this is a no-op. Included for API parity with other backends (PjitBackend, worker-pool).

Return type:

None

Multiprocessing Backend

Parallel shard evaluation across CPU cores using Python multiprocessing. Each worker process runs an independent NUTS chain on its own shard, with full JAX persistent compilation cache support.

Multiprocessing backend for CMC sharded MCMC execution.

This module provides parallel NUTS execution using Python’s multiprocessing module for CPU-based parallelism across CMC shards. Each shard runs as a separate spawned process with its own JAX initialization, avoiding JAX shared-state issues across forked processes.

Key design decisions:

  • mp_context="spawn" (not fork): JAX cannot be safely shared across fork. Spawned workers re-initialize JAX from scratch.

  • All NumPyro imports inside worker functions: spawn safety requires that no JAX/NumPyro state exists at import time in the child process.

  • Shared memory for common data: SharedDataManager places config, parameter-space state, and per-shard arrays in shared memory once, avoiding redundant serialization overhead through spawn.

  • LPT scheduling: shards dispatched highest-cost-first to minimize tail latency on identical parallel workers.

  • Heartbeat thread inside each worker: emits liveness pings so the parent can detect frozen processes and apply heartbeat_timeout.

  • Adaptive polling: poll interval grows when no shard has completed recently, shrinking CPU overhead during long-running shards.
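The adaptive polling idea can be sketched as a small pure function. All numeric defaults (base interval, ceiling, growth factor, idle threshold) and the name next_poll_interval are hypothetical; the description above only states that the interval grows when no shard has completed recently.

```python
def next_poll_interval(current, seconds_since_completion, *,
                       base=0.1, ceiling=5.0, growth=2.0, idle_after=10.0):
    """Return the next sleep time for the result-queue polling loop."""
    if seconds_since_completion < idle_after:
        return base  # a shard finished recently: poll eagerly again
    return min(current * growth, ceiling)  # idle: back off, but stay bounded


interval = 0.1
history = []
# Simulated time since the last completed shard at each poll.
for idle in (0.0, 12.0, 30.0, 60.0, 120.0, 240.0, 480.0, 1.0):
    interval = next_poll_interval(interval, idle)
    history.append(interval)
```

The ceiling keeps the parent responsive enough to enforce timeouts, while the reset to the base interval restores fast polling as soon as shards start completing again.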

Optimizations carried over from homodyne v2.22.2:

  • Batch PRNG key generation: pre-generate all shard keys in one JAX call.

  • Per-shard shared memory (packed format): 4 segments total regardless of shard count, avoiding fd exhaustion.

  • deque for pending shards: O(1) popleft instead of O(n) list.pop(0).

  • Persistent compilation cache via jax.config.update (env var alone insufficient in JAX 0.8+, min_compile_time lowered to 0).

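This backend is selected explicitly when config.backend_name == "multiprocessing" (or the legacy alias "jax"), and by the "auto" heuristic when config.num_chains >= 3 and at least 2 physical workers are available.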

class heterodyne.optimization.cmc.backends.multiprocessing.ArraySpec[source]

Bases: object

Schema for one packed shared-memory array key.

expected_dtype

Documented dtype. Advisory only — the packer uses the actual array’s dtype. Logged in shape-mismatch errors.

allow_none

Whether a caller may pass None for this key, in which case it is stored as a zero-length sentinel.

description

Human-readable label, surfaced in error messages so shape/size mismatches identify the offending physical array (e.g. “two-time correlation matrix”, not just “c2_data”).

expected_dtype: str
allow_none: bool
description: str
__init__(expected_dtype, allow_none, description)

class heterodyne.optimization.cmc.backends.multiprocessing.SharedDataManager[source]

Bases: object

Manages shared memory blocks for data common to all CMC shards.

Uses multiprocessing.shared_memory to share config dicts, parameter-space state, initial values, and per-shard arrays across spawned worker processes, avoiding redundant serialisation per shard.

Serialization note: uses pickle internally for trusted internal dicts only (CMCConfig.to_dict(), parameter-space dict). This matches the existing multiprocessing behaviour which also serialises all process arguments. External/untrusted data is never serialised here.

Must be used as a context manager or cleanup() called in a finally block to avoid leaked shared memory segments on Linux.

_shared_blocks

All allocated SharedMemory segments.

_refs

Named references returned to callers.

__init__()[source]

create_shared_bytes(name, data)[source]

Store raw bytes in a shared memory segment.

Parameters:
  • name (str) – Logical name for this block (used for bookkeeping only).

  • data (bytes) – Bytes to copy into shared memory.

Return type:

dict[str, Any]

Returns:

Reference dict with shm_name, size, and type keys.

create_shared_array(name, array)[source]

Store a numpy array in a shared memory segment.

Parameters:
  • name (str) – Logical name for this block.

  • array (ndarray) – Array to copy into shared memory (contiguous float64).

Return type:

dict[str, Any]

Returns:

Reference dict with shm_name, shape, dtype, and type keys.

create_shared_dict(name, d)[source]

Serialise a trusted internal dict into shared memory.

Only used for CMCConfig.to_dict() and parameter-space dicts. External/untrusted data is never passed here.

Parameters:
  • name (str) – Logical name for this block.

  • d (dict[str, Any]) – Dict to serialise into shared memory.

Return type:

dict[str, Any]

Returns:

Reference dict (same as create_shared_bytes()).

create_shared_shard_arrays(shard_data_list)[source]

Place per-shard numpy arrays into packed shared memory.

Instead of creating one SharedMemory segment per array per shard (n_shards * 4 = many file descriptors), this concatenates all shard arrays for each key into a single shared memory block. Only len(_SHARD_ARRAY_KEYS) segments are created regardless of shard count.

Parameters:

shard_data_list (list[dict[str, Any]]) – List of shard data dicts, each containing numpy arrays keyed by _SHARD_ARRAY_KEYS plus a scalar noise_scale. Arrays for sigma and weights may be None, in which case a zero-length sentinel is stored.

Return type:

list[dict[str, Any]]

Returns:

List of lightweight shard references (shm names + offsets). Each ref dict is small enough to serialise cheaply through spawn.
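The packed layout can be illustrated without real shared memory by concatenating the per-shard values for one key into a single buffer and handing out (offset, length) references. This is a sketch under assumptions: pack_key and unpack_key are hypothetical helpers, an array("d") stands in for the shared segment, and the actual reference dicts also carry shared-memory names.

```python
from array import array


def pack_key(per_shard_values):
    """Concatenate one key's per-shard float64 data into a single buffer."""
    packed = array("d")
    refs = []
    for values in per_shard_values:
        if values is None:
            # Zero-length sentinel: the key is absent for this shard.
            refs.append({"offset": len(packed), "length": 0})
            continue
        refs.append({"offset": len(packed), "length": len(values)})
        packed.extend(values)
    return packed, refs


def unpack_key(packed, ref):
    """Recover one shard's values from the packed buffer."""
    if ref["length"] == 0:
        return None
    start = ref["offset"]
    return packed[start:start + ref["length"]].tolist()


# Three shards; the second has no values for this key (e.g. no weights).
packed, refs = pack_key([[1.0, 2.0, 3.0], None, [4.0, 5.0]])
third = unpack_key(packed, refs[2])
second = unpack_key(packed, refs[1])
```

Because every shard shares the same buffer per key, the number of segments depends on the number of keys rather than the number of shards, which is what keeps file-descriptor usage flat.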

cleanup()[source]

Release all shared memory blocks.

Idempotent — safe to call more than once. Must be called in a finally block to avoid leaked segments.

Return type:

None

heterodyne.optimization.cmc.backends.multiprocessing.run_joint_pooled_shards_parallel(payloads, *, n_workers, num_chains, progress_bar=True, per_shard_timeout=7200)[source]

Run pooled-model shard payloads across a spawn process pool.

Reuses the proven _init_worker_jax initializer (float64, compilation cache, OpenMP thread pinning). Each element of payloads is the kwargs dict for heterodyne.optimization.cmc.core._joint_pooled_nuts_run(). Returns the per-shard CMCResult objects in input order.

Progress is reported two ways so long runs never block silently: a tqdm bar (interactive terminals) AND periodic logger.info heartbeats (visible in log files, where tqdm is not). The heartbeat fires every heartbeat seconds while waiting and reports elapsed time, completed/total shards, and time since the last completion.

per_shard_timeout bounds the wait: if no shard completes within per_shard_timeout seconds (every in-flight worker has exceeded the budget), the pool is terminated and the remaining shards are returned as failed placeholders. This prevents the het_491ee368 failure mode where a divergence-storming shard ran 21 000 s (3x the 7200 s budget) unkilled. Terminated shards become failed sub-posteriors that Consensus MC drops via its convergence_passed gate — the run degrades, it does not hang.

Return type:

list[Any]

class heterodyne.optimization.cmc.backends.multiprocessing.LPTScheduler[source]

Bases: object

Longest Processing Time scheduler for load balancing across cores.

Assigns shards to workers based on estimated computation time using the LPT (Longest Processing Time first) heuristic. The highest-cost shards are dispatched first, so the shards that start last (and tend to finish last) are the cheapest, which reduces overall tail latency.

This is a simple greedy scheduler; it does not account for real-time feedback about actual execution durations.

n_workers

Number of parallel workers.

_shard_order

Deque of shard indices sorted by descending cost.

__init__(shard_costs, n_workers)[source]

Initialise the LPT scheduler.

Parameters:
  • shard_costs (list[float]) – Estimated cost (positive float) per shard. Higher is more expensive.

  • n_workers (int) – Number of parallel workers.

classmethod from_shard_data(shard_data_list, n_workers, n_params=_N_PARAMS_HETERODYNE, n_samples=1000)[source]

Build an LPTScheduler from raw shard data dicts.

Cost is estimated via _estimate_shard_time() for each shard.

Parameters:
  • shard_data_list (list[dict[str, Any]]) – Shard dicts with "c2_data" and "noise_scale" keys.

  • n_workers (int) – Number of parallel workers.

  • n_params (int) – Number of model parameters (default: 14).

  • n_samples (int) – Expected total MCMC draws per shard.

Return type:

LPTScheduler

Returns:

Configured LPTScheduler.

next_shard()[source]

Pop and return the next shard index to dispatch.

Return type:

int | None

Returns:

Next shard index (highest remaining cost), or None when all shards have been dispatched.

remaining()[source]

Return the number of shards not yet dispatched.

Return type:

int

as_deque()[source]

Return the internal order deque (consumed by dispatch loop).

Return type:

deque[int]
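The LPT ordering and its effect on tail latency can be sketched in a few lines. The helpers lpt_order and simulate_makespan are hypothetical; the simulation assumes that whichever worker is least loaded picks up the next shard, which approximates a pool where free workers pull from the front of the deque.

```python
from collections import deque


def lpt_order(shard_costs):
    """Return shard indices sorted by descending estimated cost."""
    order = sorted(range(len(shard_costs)),
                   key=lambda idx: shard_costs[idx], reverse=True)
    return deque(order)  # popleft() is O(1), unlike list.pop(0)


def simulate_makespan(shard_costs, n_workers):
    """Greedy dispatch: the least-loaded worker takes the next shard."""
    pending = lpt_order(shard_costs)
    loads = [0.0] * n_workers
    while pending:
        shard = pending.popleft()
        worker = loads.index(min(loads))
        loads[worker] += shard_costs[shard]
    return max(loads)


costs = [3.0, 7.0, 2.0, 5.0, 4.0]
order = list(lpt_order(costs))
makespan = simulate_makespan(costs, n_workers=2)
```

For these costs the total work is 21 units, so no two-worker schedule can finish before 10.5; the greedy LPT schedule finishes at 11.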

class heterodyne.optimization.cmc.backends.multiprocessing.MultiprocessingBackend[source]

Bases: CMCBackend

CMC backend that parallelises NUTS across shards via spawned processes.

Each shard runs as an independent Python process so that JAX is initialised fresh per shard — avoiding the shared-state issues that arise when forking a process that already has a JAX runtime loaded.

Shared data (config, parameter space, initial values, per-shard arrays) is placed in SharedDataManager once in the parent and accessed via _load_shared_* in each child, minimising serialisation overhead through spawn.

The run() method provides the standard single-shard CMCBackend contract (sequential chain execution, no subprocess overhead). For multi-shard CMC, use run_shards(), which orchestrates the full parallel dispatch loop.

n_workers

Number of concurrent worker processes.

spawn_method

Multiprocessing start method (always "spawn").

_shared_mgr

Active SharedDataManager during run_shards(); None otherwise.

__init__(n_workers=None, spawn_method='spawn')[source]

Initialise the multiprocessing backend.

Parameters:
  • n_workers (int | None) – Number of worker processes. Defaults to the estimated physical core count, capped to avoid oversubscription.

  • spawn_method (str) – Process start method. Must be "spawn" for JAX safety. "fork" is explicitly unsupported.

Raises:

ValueError – If spawn_method="fork" is requested.

n_workers: int
spawn_method: str

run(model, config, rng_key, init_params=None)[source]

Run NUTS sampling for a single shard (standard CMCBackend contract).

For multi-shard CMC, call run_shards() instead. This method provides API parity with CPUBackend and PjitBackend using sequential chain execution and no subprocess overhead.

Parameters:
  • model (Callable[..., Any]) – NumPyro model function.

  • config (CMCConfig) – CMC configuration.

  • rng_key (Array) – JAX PRNG key.

  • init_params (dict[str, Array] | None) – Optional per-chain initial values.

Return type:

dict[str, Any]

Returns:

Dictionary of posterior samples from all chains.

Raises:

RuntimeError – If MCMC sampling fails.

get_capabilities()[source]

Return multiprocessing backend capabilities.

Return type:

BackendCapabilities

Returns:

BackendCapabilities indicating sharding support, parallel shards equal to n_workers.

validate_resources()[source]

Check that CPU resources and multiprocessing are available.

Raises:

RuntimeError – If no JAX CPU device is found or if the multiprocessing module cannot create a spawn context.

Return type:

None

estimate_memory(n_data, n_params, n_chains)[source]

Estimate peak memory for all concurrent workers combined.

Conservative upper bound: each worker holds one chain’s live state (params + momentum + gradients) plus the data buffer. Workers run chains sequentially so n_chains does not multiply within a single worker.

Parameters:
  • n_data (int) – Number of data points per shard.

  • n_params (int) – Number of model parameters.

  • n_chains (int) – Number of MCMC chains per shard (not used for per-worker estimate; included for API uniformity).

Return type:

float

Returns:

Estimated peak memory in gigabytes.

cleanup()[source]

Release shared memory and any other resources.

Idempotent — safe to call multiple times.

Return type:

None

run_shards(shards, config, initial_values=None, parameter_space=None, prior_width_multiplier=1.0, nlsq_uncertainties=None, nlsq_prior_width_factor=2.0, progress_bar=True)[source]

Run NUTS in parallel across all CMC shards.

Orchestrates the full parallel dispatch loop:

  1. Allocate shared memory for config, parameter space, initial values, and per-shard arrays.

  2. Pre-generate all PRNG keys in the parent process.

  3. Dispatch shards to worker processes in LPT order.

  4. Drain the result queue with adaptive polling.

  5. Enforce per-shard and heartbeat timeouts.

  6. Validate and return successful shard results.

Parameters:
  • shards (list[dict[str, Any]]) – List of shard dicts. Each must contain at minimum c2_data (numpy array). Optional keys: sigma, t, weights, noise_scale, q, dt, phi_angle, contrast, offset, reparam_config_dict.

  • config (CMCConfig) – CMC configuration with NUTS hyperparameters and timeout settings.

  • initial_values (dict[str, Any] | None) – Optional NLSQ warm-start values shared across all shards.

  • parameter_space (Any | None) – Optional ParameterSpace instance. Its internal config dict is serialised into shared memory.

  • progress_bar (bool) – Whether to show a tqdm progress bar.

Return type:

list[dict[str, Any]]

Returns:

List of validated successful result dicts, one per succeeded shard. Each dict contains shard_idx, samples, n_chains, n_samples, param_names, extra_fields, duration, and stats.

Raises:
  • ValueError – If shards is empty.

  • RuntimeError – If all shards fail, or if the success rate falls below config.min_success_rate.

is_available()[source]

Check whether this backend can run on the current platform.

Return type:

bool

Returns:

True if the spawn multiprocessing context is available.

PBS/HPC Backend

Submits per-shard NUTS runs as independent PBS batch jobs for use on HPC clusters that run PBS (e.g. ALCF Polaris).

PBS/Torque job submission backend for Consensus Monte Carlo.

Submits per-shard MCMC sampling as PBS batch jobs, collects results from completed jobs, and combines them. Designed for HPC clusters running PBS Professional or Torque where each shard runs on a separate node.

Usage:

from heterodyne.optimization.cmc.backends.pbs import PBSBackend, PBSConfig

cfg = PBSConfig(queue="large", walltime="04:00:00", nodes=1, ppn=8)
backend = PBSBackend(pbs_config=cfg)
samples = backend.run(model_fn, cmc_config, rng_key)
class heterodyne.optimization.cmc.backends.pbs.PBSConfig[source]

Bases: object

Configuration for PBS/Torque job submission.

queue

Target PBS queue name (e.g. "batch").

walltime

Maximum wall-clock time in HH:MM:SS format.

nodes

Number of nodes per shard job.

ppn

Processors per node.

memory

Memory request per node (e.g. "4gb").

python_executable

Python interpreter accessible in the PBS job environment (full path or bare name like "python3").

working_dir

Directory for temporary files. Defaults to tempfile.gettempdir() when None.

extra_pbs_directives

Raw #PBS lines injected verbatim after the standard resource block.

poll_interval

Seconds between qstat polls.

max_retries

Re-submission attempts for failed shards.

cleanup_on_success

Delete temporary files after successful runs.

queue: str = 'batch'
walltime: str = '01:00:00'
nodes: int = 1
ppn: int = 1
memory: str = '4gb'
python_executable: str = 'python'
working_dir: str | None = None
extra_pbs_directives: list[str]
poll_interval: float = 30.0
max_retries: int = 2
cleanup_on_success: bool = True
__post_init__()[source]

Validate every shell-bound field at construction time.

Misconfigured PBS jobs fail here (cheap) instead of after a qsub round-trip (expensive) or, worse, after the malformed script runs. Each validator raises ValueError with a message that names the field and the offending characters.

Return type:

None
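
A construction-time validator of the kind described above might look like the following sketch. The function names, regular expressions, and allowed character set are assumptions made for illustration; the actual rules enforced by PBSConfig may differ.

```python
import re

# Assumed formats: HH:MM:SS walltime and a conservative queue-name alphabet.
_WALLTIME_RE = re.compile(r"\d{2,}:[0-5]\d:[0-5]\d")
_SAFE_CHARS = r"[A-Za-z0-9_.-]"


def validate_walltime(value: str) -> str:
    """Reject anything that is not a plain HH:MM:SS string."""
    if not _WALLTIME_RE.fullmatch(value):
        raise ValueError(f"walltime must be HH:MM:SS, got {value!r}")
    return value


def validate_queue(value: str) -> str:
    """Reject empty names and names with characters unsafe in a shell script."""
    if not value:
        raise ValueError("queue must not be empty")
    # Whatever survives after stripping the safe characters is offending.
    bad = sorted(set(re.sub(_SAFE_CHARS, "", value)))
    if bad:
        raise ValueError(f"queue contains unsafe characters {bad!r}: {value!r}")
    return value


validate_walltime("04:00:00")  # accepted
validate_queue("large")        # accepted
```

Raising at construction keeps the failure next to the misconfigured value, before any script is written or submitted.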

__init__(queue='batch', walltime='01:00:00', nodes=1, ppn=1, memory='4gb', python_executable='python', working_dir=None, extra_pbs_directives=<factory>, poll_interval=30.0, max_retries=2, cleanup_on_success=True)
class heterodyne.optimization.cmc.backends.pbs.ShardResult[source]

Bases: object

Result from a single shard MCMC job.

shard_id

Zero-based shard index.

samples

Posterior samples keyed by parameter name.

job_id

PBS job identifier string.

success

True when sampling completed without errors.

error_message

Populated when success is False.

shard_id: int
samples: dict[str, Any]
job_id: str
success: bool = True
error_message: str = ''
__init__(shard_id, samples, job_id, success=True, error_message='')
class heterodyne.optimization.cmc.backends.pbs.PBSBackend[source]

Bases: CMCBackend

PBS/Torque backend for distributed CMC sampling.

Each data shard is submitted as an independent PBS batch job. The main process polls qstat until all jobs terminate, then reads per-shard .npz result files and concatenates the samples.

Parameters:

pbs_config (PBSConfig | None) – PBS resource and scheduling options. Defaults to a PBSConfig() with queue="batch" when None.

__init__(pbs_config=None)[source]
run(model, config, rng_key, init_params=None)[source]

Submit per-shard PBS jobs, wait for completion, return combined samples.

Parameters:
  • model (Callable[..., Any]) – NumPyro model function. Must be picklable.

  • config (CMCConfig) – CMC configuration with sampling hyperparameters.

  • rng_key (Array) – JAX PRNG key used to derive per-shard integer seeds.

  • init_params (dict[str, Array] | None) – Optional per-chain initial parameter values.

Return type:

dict[str, Any]

Returns:

Dictionary mapping parameter names to concatenated numpy arrays.

Raises:
  • RuntimeError – If PBS is unavailable or any shard fails.

  • TimeoutError – If config.shard_timeout_seconds is exceeded.

get_capabilities()[source]

Return PBS backend capabilities.

Return type:

BackendCapabilities

validate_resources()[source]

Check that qsub is on PATH.

Raises:

RuntimeError – If qsub is not found.

Return type:

None

estimate_memory(n_data, n_params, n_chains)[source]

Estimate peak memory per PBS node (in GB) for a single shard.

Because each shard runs on a separate node, the result is per-node only. The estimate is conservative (upper-bound) to prevent OOM kills.

Return type:

float

cleanup(job_ids=None)[source]

Cancel active jobs and delete temporary files.

Parameters:

job_ids (list[str] | None) – PBS job IDs to cancel (best-effort). None skips cancellation but still removes temporary files.

Return type:

None

submit_shard(shard_data, model_fn, config_dict, shard_id, seed)[source]

Serialize shard payload, write PBS script, and submit via qsub.

Parameters:
  • shard_data (dict[str, Any]) – Auxiliary shard-specific data (e.g. init_params).

  • model_fn (Callable[..., Any]) – Picklable NumPyro model function.

  • config_dict (dict[str, Any]) – Flat NUTS hyperparameter dict.

  • shard_id (int) – Zero-based shard index (drives file naming).

  • seed (int) – Integer random seed for this shard.

Return type:

str

Returns:

PBS job ID string.

Raises:

RuntimeError – If qsub fails.

wait_for_jobs(job_ids, timeout=None)[source]

Poll qstat until all jobs reach a terminal state.

Jobs absent from qstat (purged from the accounting database) are treated as complete; the .npz file determines success or failure.

Parameters:
  • job_ids (list[str]) – PBS job IDs in shard order.

  • timeout (float | None) – Maximum seconds to wait (None = unlimited).

Return type:

list[ShardResult]

Returns:

List of ShardResult in the same order as job_ids.

Raises:

TimeoutError – If timeout elapses before all jobs finish.
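
The polling behaviour described for wait_for_jobs can be sketched as below. This is an illustration under stated assumptions, not the backend's implementation: get_state is a hypothetical callable standing in for a qstat query, it returns None when the job is no longer listed, and the terminal state codes are assumed.

```python
import time
from typing import Callable, Iterable, Optional

# Assumed terminal state codes; real schedulers may report others.
TERMINAL_STATES = {"F", "C", "E"}


def wait_until_done(
    job_ids: Iterable[str],
    get_state: Callable[[str], Optional[str]],
    poll_interval: float = 30.0,
    timeout: Optional[float] = None,
) -> None:
    """Poll until every job is terminal or absent; raise TimeoutError otherwise."""
    deadline = None if timeout is None else time.monotonic() + timeout
    pending = set(job_ids)
    while pending:
        for job_id in list(pending):
            state = get_state(job_id)
            # A job missing from the listing (None) is treated as complete.
            if state is None or state in TERMINAL_STATES:
                pending.discard(job_id)
        if not pending:
            break
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"{len(pending)} job(s) still running after {timeout} s"
            )
        time.sleep(poll_interval)
```

In this sketch the poller only decides when to stop waiting. Whether a shard succeeded is still determined afterwards from its result file, as noted above.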

run_shards(shards, model_fn, config, seeds)[source]

Submit and collect an explicit list of pre-partitioned shards.

Use this when the caller has already split the data and needs to attach per-shard payloads. For the standard pipeline use run().

Parameters:
  • shards (list[dict[str, Any]]) – Per-shard data dicts passed verbatim to submit_shard.

  • model_fn (Callable[..., Any]) – Picklable NumPyro model function.

  • config (CMCConfig) – CMC configuration.

  • seeds (list[int]) – One integer seed per shard.

Return type:

list[ShardResult]

Returns:

List of ShardResult in shard order.

Raises:

Pjit Backend (Multi-Device)

Experimental backend that shards chains across jax.devices() using the jax.sharding API (the successor to pjit). Currently CPU-only because heterodyne ships no GPU support, but the interface is device-agnostic.

Multi-device parallel MCMC backend using JAX sharding.

Distributes NUTS chains across multiple CPU devices using JAX’s modern sharding API (jax.sharding), available since JAX 0.4.1. This replaces the deprecated jax.experimental.pjit with the stable jax.jit + sharding annotation approach. Heterodyne is CPU-only.

class heterodyne.optimization.cmc.backends.pjit.PjitBackend[source]

Bases: CMCBackend

Multi-device parallel MCMC backend using JAX sharding.

Distributes NUTS chains across all available JAX devices. Each device runs a subset of the requested chains in parallel via NumPyro’s vectorized chain execution.

When only a single device is available, this backend transparently falls back to running all chains on that device (equivalent to the single-device chain_method="parallel").

This backend uses the modern jax.sharding API (stable since JAX 0.4.1), not the deprecated jax.experimental.pjit.

run(model, config, rng_key, init_params=None)[source]

Run NUTS sampling distributed across multiple devices.

Splits chains across available devices. Each device group runs independently, and results are gathered and concatenated.

Parameters:
  • model (Callable[..., Any]) – NumPyro model function.

  • config (CMCConfig) – CMC configuration with sampling hyperparameters.

  • rng_key (Array) – JAX PRNG key for reproducibility.

  • init_params (dict[str, Array] | None) – Optional per-chain initial parameter values.

Return type:

dict[str, Any]

Returns:

Dictionary mapping parameter names to flat sample arrays.

Raises:

RuntimeError – If sampling fails on any device.
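
One plausible way to split chains across devices is an even split with the remainder going to the first devices, sketched below. The helper name split_chains and the remainder policy are assumptions; the backend may allocate chains differently.

```python
def split_chains(n_chains: int, n_devices: int) -> list[int]:
    """Return the number of chains assigned to each device that gets work."""
    if n_chains < 1 or n_devices < 1:
        raise ValueError("n_chains and n_devices must both be >= 1")
    base, extra = divmod(n_chains, n_devices)
    # The first `extra` devices take one additional chain.
    counts = [base + (1 if i < extra else 0) for i in range(n_devices)]
    # Devices with no chains are dropped (more devices than chains).
    return [c for c in counts if c > 0]


single = split_chains(4, 1)   # single-device fallback: all chains on one device
uneven = split_chains(6, 4)   # remainder spread over the first devices
```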

get_capabilities()[source]

Return capabilities for multi-device parallel execution.

Return type:

BackendCapabilities

Returns:

BackendCapabilities with sharding support flags.

validate_resources()[source]

Check that multiple devices are available for sharding.

Logs a warning (but does not raise) when only one device is detected, since the backend can still function in single-device mode.

Raises:

RuntimeError – If no JAX devices are available at all.

Return type:

None

estimate_memory(n_data, n_params, n_chains)[source]

Estimate peak memory per device in GB.

Each device holds a fraction of the chains. Memory per chain is approximately n_params × n_data × 8 bytes (float64) for the likelihood evaluation, plus sample storage.

Parameters:
  • n_data (int) – Number of data points per shard.

  • n_params (int) – Number of model parameters.

  • n_chains (int) – Total number of MCMC chains.

Return type:

float

Returns:

Estimated peak memory in GB per device.
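
As a rough worked example of the per-chain term above, the sketch below evaluates only the likelihood term and ignores sample storage. The n_devices argument, the ceiling division of chains over devices, and the 1e9 bytes-per-GB convention are assumptions, so the real method may return a larger figure.

```python
import math


def likelihood_memory_per_device_gb(
    n_data: int, n_params: int, n_chains: int, n_devices: int = 1
) -> float:
    """Likelihood-evaluation memory per device in GB (sample storage excluded)."""
    chains_per_device = math.ceil(n_chains / n_devices)
    bytes_per_chain = n_params * n_data * 8  # float64 entries
    return chains_per_device * bytes_per_chain / 1e9


# Illustrative numbers only: 1e6 data points, 10 parameters, 4 chains, 2 devices.
# 2 chains per device x 10 x 1_000_000 x 8 bytes = 160_000_000 bytes = 0.16 GB
estimate = likelihood_memory_per_device_gb(1_000_000, 10, 4, n_devices=2)
```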

cleanup()[source]

Release resources. No-op for JAX-managed devices.

Return type:

None

heterodyne.optimization.cmc.backends.pjit.combine_shard_samples(shard_results)[source]

Combine posterior samples from multiple device shards.

Concatenates sample arrays along the first axis (samples dimension).

Parameters:

shard_results (list[dict[str, Any]]) – List of sample dictionaries, one per device shard. Each dict maps parameter names to numpy/JAX arrays.

Return type:

dict[str, Any]

Returns:

Combined dictionary with concatenated samples.

Raises:

ValueError – If shard_results is empty.
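
The concatenation described above can be sketched with NumPy as follows. This is an illustration rather than the library's implementation: combine_samples is a hypothetical name, and the sketch assumes every shard dict carries the same parameter names as the first one.

```python
from typing import Any

import numpy as np


def combine_samples(shard_results: list[dict[str, Any]]) -> dict[str, np.ndarray]:
    """Concatenate per-shard sample arrays along axis 0 for each parameter."""
    if not shard_results:
        raise ValueError("shard_results is empty")
    names = list(shard_results[0].keys())
    return {
        # np.asarray also accepts JAX arrays and converts them to NumPy.
        name: np.concatenate(
            [np.asarray(result[name]) for result in shard_results], axis=0
        )
        for name in names
    }


# Two toy shards with 3 and 2 draws of a single parameter.
shard_a = {"D0_sample": np.arange(3.0)}
shard_b = {"D0_sample": np.arange(3.0, 5.0)}
combined = combine_samples([shard_a, shard_b])  # 5 draws in total
```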

Persistent Worker Pool

Long-lived worker pool that amortises per-shard JAX JIT compilation overhead. Recommended for large analyses with many shards.

Worker pool backend for multi-shard CMC execution.

Spawns persistent workers that each run MCMC on assigned shards. Amortizes JAX/NumPyro initialization overhead across tasks.

class heterodyne.optimization.cmc.backends.worker_pool.WorkerPoolBackend[source]

Bases: object

Persistent worker pool for multi-shard CMC execution.

Distributes MCMC shards across a pool of worker processes. Each worker runs one chain per shard, and results are combined.

__init__(n_workers=None)[source]

Initialize with optional worker count.

Parameters:

n_workers (int | None) – Number of workers. Defaults to physical_cores - 1 (reserving one core for the main process). Uses detect_cpu_info() for accurate physical core detection; falls back to os.cpu_count() // 2 if detection fails.

property n_workers: int
get_name()[source]
Return type:

str

static should_use_pool(n_shards, n_workers)[source]

Check if pool execution is beneficial.

Homodyne parity: pool is used whenever there are at least 3 shards, regardless of the worker count. The pool’s startup cost is amortised across shards, not workers, so the gate only needs to reject trivially small shard counts.

Parameters:
  • n_shards (int) – Number of data shards.

  • n_workers (int) – Available workers (unused; kept for API stability).

Return type:

bool

Returns:

True when n_shards is at least 3, the point at which the pool’s startup cost is considered amortised.

run(model, config, rng_key, init_params=None)[source]

Run MCMC via worker pool.

Parameters:
  • model (Callable[..., Any]) – NumPyro model function.

  • config (CMCConfig) – CMC configuration.

  • rng_key (Array) – JAX PRNG key (used to generate per-shard seeds).

  • init_params (dict[str, Array] | None) – Optional initial values (not used in pool mode).

Return type:

dict[str, Any]

Returns:

Combined posterior samples from all workers.

heterodyne.optimization.cmc.backends.worker_pool.should_use_persistent_pool(n_shards, n_workers)[source]

Homodyne-parity gate: persistent pool helps when there are >=3 shards.

Return type:

bool

class heterodyne.optimization.cmc.backends.worker_pool.PersistentWorkerPool[source]

Bases: object

Persistent process pool for CMC shard dispatch (homodyne parity).

Workers are spawned once, perform a one-time initialization (e.g. JAX warm-up / JIT priming) under control of worker_init_fn, signal readiness, and then process tasks from a per-worker queue until a None sentinel is received. The pool exposes a round-robin submit(), a blocking get_result(), and a deterministic shutdown() that drains task queues, joins processes, and terminates any that refuse to exit.

The class supports the context-manager protocol; __exit__ calls shutdown().

Parameters:
  • n_workers (int) – Number of persistent worker processes.

  • worker_fn (Callable[..., dict[str, Any] | None]) – Picklable module-level callable invoked per task. Signature worker_fn(task: dict, **init_kwargs) -> dict | None. Returning None skips putting a result on the shared queue, which is useful when worker_fn already manages its own result emission.

  • worker_init_kwargs (dict[str, Any] | None) – One-time kwargs forwarded to both worker_init_fn (if provided) and every worker_fn call. Must be picklable.

  • worker_init_fn (Callable[..., None] | None) – Optional one-time initializer with signature init_fn(worker_id: int, **init_kwargs) -> None. Use it for expensive setup like JAX backend selection or JIT pre-warming.

  • startup_timeout (float) – Maximum seconds to wait for all workers to signal readiness before continuing with whatever subset is ready.
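
The dispatch pattern described above (per-worker task queues, round-robin submit, a shared result queue, and None sentinels on shutdown) can be illustrated with the sketch below. It deliberately uses threads instead of processes so that it stays short and self-contained. MiniPool and square are hypothetical names, and the sketch omits the one-time initializer and readiness signalling.

```python
import queue
import threading
from typing import Any, Callable, Optional


class MiniPool:
    """Thread-based stand-in for a persistent worker pool (illustration only)."""

    def __init__(
        self, n_workers: int, worker_fn: Callable[[dict], Optional[dict]]
    ) -> None:
        self._task_queues = [queue.Queue() for _ in range(n_workers)]
        self.result_queue: queue.Queue = queue.Queue()
        self._next = 0
        self._threads = [
            threading.Thread(target=self._loop, args=(q, worker_fn), daemon=True)
            for q in self._task_queues
        ]
        for thread in self._threads:
            thread.start()

    def _loop(self, tasks: queue.Queue, worker_fn: Callable) -> None:
        while True:
            task = tasks.get()
            if task is None:  # sentinel: stop this worker
                break
            result = worker_fn(task)
            if result is not None:  # None means nothing to report
                self.result_queue.put(result)

    def submit(self, task: dict) -> None:
        """Hand the task to the next worker in round-robin order."""
        self._task_queues[self._next].put(task)
        self._next = (self._next + 1) % len(self._task_queues)

    def get_result(self, timeout: float = 300.0) -> Any:
        """Block for a result; raises queue.Empty on timeout."""
        return self.result_queue.get(timeout=timeout)

    def shutdown(self, timeout: float = 10.0) -> None:
        for tasks in self._task_queues:
            tasks.put(None)
        for thread in self._threads:
            thread.join(timeout)

    def __enter__(self) -> "MiniPool":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.shutdown()


def square(task: dict) -> dict:
    return {"shard_idx": task["shard_idx"], "value": task["x"] ** 2}


with MiniPool(2, square) as pool:
    for i in range(4):
        pool.submit({"shard_idx": i, "x": i})
    # Results arrive in completion order, so sort before comparing.
    results = sorted(pool.get_result(timeout=5.0)["value"] for _ in range(4))
```

A process-based pool additionally needs picklable callables and arguments, which is why worker_fn and worker_init_kwargs carry that requirement.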

__init__(n_workers, worker_fn, worker_init_kwargs=None, worker_init_fn=None, startup_timeout=120.0)[source]
property n_workers: int
property result_queue: Queue
is_alive()[source]
Return type:

bool

submit(task)[source]

Round-robin submit task to the next worker’s queue.

Return type:

None

get_result(timeout=300.0)[source]

Block until a result is available; raises queue.Empty on timeout.

Return type:

dict[str, Any]

results_pending()[source]

True when the shared result queue has at least one entry.

Return type:

bool

shutdown(timeout=10.0)[source]

Send None sentinels to all workers and join them.

Return type:

None