CMC (Bayesian)
Bayesian posterior sampling via NumPyro NUTS, with NLSQ-derived warm-start initialization, configurable priors, reparameterization, and ArviZ-compatible convergence diagnostics.
Core Fitting
Core CMC fitting functions for heterodyne Bayesian analysis.
Includes the original single-run fit_cmc_jax and the new sharded
Consensus Monte Carlo entry point fit_cmc_sharded, plus all supporting
helpers for shard creation, prior tempering, and posterior combination.
- heterodyne.optimization.cmc.core.CMC_F0_DEGEN_THRESHOLD: float = 0.1
Sample fraction (f0) below which alpha_sample, D0_sample, and D_offset_sample become unidentifiable, causing 100% shard failure (the het_bb97531f failure mode).
- heterodyne.optimization.cmc.core.CMC_ALPHA_SINGULARITY: float = -1.5
alpha_sample value at which J_sample(t) ∝ t^α has a non-integrable singularity at t→0, collapsing the NUTS step size for the sample group.
- heterodyne.optimization.cmc.core.CMC_F0_SAFE_ZONE: float = 0.2
Target value for f0 when auto-clamping a degenerate warm-start. Set to 2× the degeneracy threshold (not just above it) so that NUTS leapfrog steps cannot easily reflect back into the unidentifiable region. The borderline value CMC_F0_DEGEN_THRESHOLD + 0.01 = 0.11 used in het_a10cf27e produced 47/47 shard failures even after clamping; this wider margin is the het_a10cf27e fix.
- heterodyne.optimization.cmc.core.CMC_ALPHA_SAFE_ZONE: float = -1.0
Target value for alpha_sample when auto-clamping. Sits 0.5 above the t^α singularity (vs. the previous borderline margin of 0.1). The wider margin keeps NUTS step-size adaptation stable when warm-starting from the degenerate region.
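The four constants above work as a guard-and-clamp pair: the thresholds detect a degenerate warm-start, and the safe zones give the clamp targets. The following is a minimal sketch of those comparisons. It assumes the warm-start is a plain dict keyed by `f0` and `alpha_sample`. The helper names `is_degenerate_warmstart` and `clamp_warmstart` are hypothetical, and the real module may order the abort and clamp steps differently.

```python
# Values copied from the constants documented above.
CMC_F0_DEGEN_THRESHOLD = 0.1
CMC_ALPHA_SINGULARITY = -1.5
CMC_F0_SAFE_ZONE = 0.2
CMC_ALPHA_SAFE_ZONE = -1.0


def is_degenerate_warmstart(params: dict) -> bool:
    """Hypothetical check: True if the warm-start sits in a known failure regime."""
    f0 = params.get("f0")
    alpha = params.get("alpha_sample")
    f0_bad = f0 is not None and f0 < CMC_F0_DEGEN_THRESHOLD
    alpha_bad = alpha is not None and alpha < CMC_ALPHA_SINGULARITY
    return f0_bad or alpha_bad


def clamp_warmstart(params: dict) -> dict:
    """Hypothetical clamp: move degenerate values to the wider safe zones."""
    out = dict(params)  # leave the caller's dict untouched
    if out.get("f0") is not None and out["f0"] < CMC_F0_DEGEN_THRESHOLD:
        out["f0"] = CMC_F0_SAFE_ZONE
    if out.get("alpha_sample") is not None and out["alpha_sample"] < CMC_ALPHA_SINGULARITY:
        out["alpha_sample"] = CMC_ALPHA_SAFE_ZONE
    return out
```

Clamping to the safe zone rather than to the threshold mirrors the rationale above. A value that is only just past the threshold can be pushed back into the degenerate region by the first few leapfrog steps.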
- heterodyne.optimization.cmc.core.fit_cmc_jax(model, c2_data, phi_angle=0.0, config=None, sigma=None, nlsq_result=None, t_override=None, priors_override=None, prior_width_multiplier=1.0)
Fit the heterodyne model with a single (non-sharded) NUTS run.
Uses NumPyro’s NUTS sampler for Bayesian posterior inference. fit_cmc_sharded reuses this function per shard, passing the shard’s time slice and tempered priors.
- Parameters:
  - model (HeterodyneModel) – HeterodyneModel with configured parameters.
  - phi_angle (float) – Detector phi angle (degrees).
  - config (CMCConfig | None) – CMC configuration (default if None).
  - sigma (ndarray | float | None) – Measurement uncertainty (estimated if None).
  - nlsq_result (NLSQResult | None) – Optional NLSQ result for warm-starting.
  - t_override (ndarray | None) – Optional time array replacing model.t for model construction. Used by fit_cmc_sharded to pass shard time slices. If None, falls back to model.t.
  - priors_override (dict | None) – Optional dict of pre-built NumPyro distributions keyed by parameter name. When provided, these distributions replace the default space.priors for matching parameters. Used by fit_cmc_sharded to inject tempered shard priors into the non-reparam model path.
  - prior_width_multiplier (float) – Scalar multiplier applied to the scale of each reparam-path prior AFTER nlsq_prior_width_factor scaling. Default 1.0 (no change). Used by fit_cmc_sharded to widen reparam priors by sqrt(K).
- Return type: CMCResult
- Returns: CMCResult with posterior samples and diagnostics.
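The prior-width knobs compose multiplicatively. The following is a minimal sketch. It assumes the informed prior's base scale is an NLSQ standard error; the section says only that nlsq_prior_width_factor scales NLSQ uncertainties. `reparam_prior_scale` is a hypothetical name.

The second half of the sketch checks numerically why fit_cmc_sharded can pass `prior_width_multiplier = sqrt(K)`. For a Gaussian, dividing the log-prior by K and widening the scale by sqrt(K) give the same density up to normalisation.

```python
import numpy as np


def reparam_prior_scale(nlsq_std, nlsq_prior_width_factor=2.0,
                        prior_width_multiplier=1.0):
    """Hypothetical helper: NLSQ-informed scale, widened multiplicatively."""
    return nlsq_std * nlsq_prior_width_factor * prior_width_multiplier


def normal_logpdf(x, loc, scale):
    """Log-density of a Normal(loc, scale) distribution."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


K = 4                                   # number of shards
loc = 1.0
base_scale = reparam_prior_scale(0.05)  # 0.05 * 2.0 = 0.1
shard_scale = reparam_prior_scale(0.05, prior_width_multiplier=np.sqrt(K))

x = np.linspace(0.0, 2.0, 11)
tempered = normal_logpdf(x, loc, base_scale) / K   # log of prior**(1/K)
widened = normal_logpdf(x, loc, shard_scale)       # Normal with scale * sqrt(K)

# The two log-densities differ only by a constant (normalisation), so the
# widened Normal is the tempered prior up to renormalisation.
offset = widened - tempered
assert np.allclose(offset, offset[0])
```

The same argument holds for a truncated Normal with unchanged bounds, because the truncation indicator raised to the power 1/K is still the same indicator.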
- heterodyne.optimization.cmc.core.fit_cmc_sharded(model, c2_data, phi_angle=0.0, config=None, sigma=None, nlsq_result=None, num_shards=4, sharding_strategy='random', shard_seed=None)
Fit the heterodyne model using sharded Consensus Monte Carlo.
Splits the observed c2 matrix into num_shards independent data subsets, runs NUTS on each shard sub-posterior (sequentially), then combines the shard posteriors via inverse-variance weighted consensus.
Prior tempering is applied automatically: each shard’s prior distribution is widened by sqrt(num_shards) (i.e., prior^(1/K)) while sigma is passed unscaled. For a Gaussian prior the two are equivalent, because raising N(μ, s²) to the power 1/K yields a density proportional to N(μ, K·s²). This is the standard Consensus Monte Carlo construction (Scott et al., 2016).
- Parameters:
  - model (HeterodyneModel) – HeterodyneModel with configured parameters.
  - c2_data (ndarray | Array) – Observed two-time correlation matrix (N x N).
  - phi_angle (float) – Detector phi angle (degrees).
  - config (CMCConfig | None) – CMC configuration (defaults to CMCConfig()).
  - sigma (ndarray | float | None) – Measurement uncertainty (estimated if None).
  - nlsq_result (NLSQResult | None) – Optional NLSQ result for warm-starting each shard.
  - num_shards (int) – Number of data shards (K). Must be >= 2.
  - sharding_strategy (str) – One of "random" (default) or "contiguous". Random sharding breaks temporal autocorrelation between shards. Contiguous sharding uses diagonal time-blocks, which preserves the two-time structure within each shard.
  - shard_seed (int | None) – Integer seed for deterministic shard assignment. If None, a random seed is drawn from the OS.
- Return type: CMCResult
- Returns: CMCResult with the combined posterior and per-shard diagnostics stored in result.metadata["shard_diagnostics"].
- Raises: ValueError – If inputs fail validation or num_shards < 2.
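The two sharding strategies can be sketched on an N × N c2 matrix, with each shard represented as an array of (i, j) cell indices. The sketch makes two assumptions:

- Random sharding shuffles all cells and splits them into near-equal parts.
- Contiguous sharding splits the time axis into K diagonal blocks and keeps only cells whose two times fall in the same block.

`make_shards` is a hypothetical helper. The real fit_cmc_sharded may treat off-diagonal-block cells differently.

```python
from __future__ import annotations

import numpy as np


def make_shards(n: int, num_shards: int, strategy: str = "random",
                shard_seed: int | None = None) -> list:
    """Hypothetical helper: per shard, an (m, 2) array of (i, j) indices into c2."""
    if num_shards < 2:
        raise ValueError("num_shards must be >= 2")
    if strategy == "random":
        # Shuffle all (i, j) cells, then split into near-equal shards.
        # shard_seed=None makes NumPy draw fresh entropy from the OS.
        rng = np.random.default_rng(shard_seed)
        ii, jj = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
        pairs = np.column_stack([ii.ravel(), jj.ravel()])
        return np.array_split(pairs[rng.permutation(len(pairs))], num_shards)
    if strategy == "contiguous":
        # Diagonal time blocks: shard k keeps cells whose i and j both lie in block k.
        shards = []
        for idx in np.array_split(np.arange(n), num_shards):
            ii, jj = np.meshgrid(idx, idx, indexing="ij")
            shards.append(np.column_stack([ii.ravel(), jj.ravel()]))
        return shards
    raise ValueError(f"unknown sharding_strategy: {strategy!r}")
```

In this sketch the contiguous strategy uses only K·(N/K)² of the N² cells. That is one reading of "diagonal time-blocks", and it illustrates why the two strategies see different amounts of data per shard.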
- heterodyne.optimization.cmc.core.fit_cmc_multi_phi(model, c2_data, phi_angles, config=None, nlsq_results=None, sigma=None)
Joint multi-phi CMC entry point (homodyne parity).
Fits the pooled multi-phi data with 14 shared physics parameters and per-angle contrast / offset. Mirrors homodyne’s _fit_mcmc_jax_impl: small datasets run a single NUTS pass; large datasets (n_total above the single-shard limit, or when num_shards / max_points_per_shard is set explicitly) are sharded and combined by Consensus Monte Carlo. Sharding is needed because NUTS costs O(n) per leapfrog step, so a single pass over millions of pooled points is intractable.
Algorithm:
  1. Pool c2_data (shape (n_phi, N, N), or (N, N) for n_phi=1) into flat arrays (data, t1, t2, phi) of length n_phi * N * N.
  2. Run prepare_mcmc_data() to filter the diagonal and build a PooledCMCData container with phi_unique and phi_indices (homodyne layout).
  3. Compute per-point grid indices i1_indices / i2_indices via searchsorted against model.t.
  4. Build the joint NumPyro model xpcs_model_heterodyne_scaled() (per-angle sampled scaling + shared physics + a single pooled likelihood with a t=0 boundary mask).
  5. Run NUTS with the configured chains / warmup / samples.
  6. Return a single CMCResult with the shared-physics posterior plus per-angle scaling posteriors (mean_contrast / mean_offset arrays of length n_phi).
The returned CMCResult reflects one joint inference: every angle contributes to the same physics-parameter posterior, exactly as in homodyne.
- Parameters:
  - model (HeterodyneModel) – Configured HeterodyneModel whose time grid model.t defines the (N,) axis the pooled c2 was flattened from.
  - c2_data (ndarray | Array) – Experimental c2 of shape (n_phi, N, N) (multi-angle) or (N, N) (single-angle, treated as n_phi=1).
  - phi_angles (ndarray | list[float]) – Detector phi angles in degrees, length n_phi.
  - config (CMCConfig | None) – CMCConfig. None uses defaults.
  - nlsq_results (list[NLSQResult] | None) – Optional per-angle NLSQ warm-start. Currently only used to log warm-start status; future phases will translate it to init_to_value.
  - sigma (ndarray | float | None) – Optional measurement uncertainty estimate; None triggers a MAD-based estimate via prepare_mcmc_data().
- Returns: Joint multi-phi result. parameter_names lists the 14 physics parameters followed by contrast_0..contrast_{n_phi-1} and offset_0..offset_{n_phi-1}.
- Return type: CMCResult
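Steps 1 to 3 of the algorithm can be sketched with NumPy alone. This is an illustration under assumptions, not prepare_mcmc_data(). The sketch:

- pools an (n_phi, N, N) array,
- drops the t1 == t2 diagonal,
- labels each point with its angle index via np.unique,
- recovers per-point grid indices with np.searchsorted.

The function name `pool_c2` is hypothetical.

```python
import numpy as np


def pool_c2(c2, t, phi_angles):
    """Hypothetical pooling helper mirroring steps 1-3 above."""
    c2 = np.asarray(c2, dtype=float)
    if c2.ndim == 2:                      # (N, N) is treated as n_phi = 1
        c2 = c2[None, ...]
    phi_angles = np.atleast_1d(np.asarray(phi_angles, dtype=float))
    n_phi, n, _ = c2.shape
    if phi_angles.size != n_phi:
        raise ValueError("need one phi angle per c2 slice")

    # Step 1: flatten into pooled (data, t1, t2, phi) of length n_phi * N * N.
    t1_grid, t2_grid = np.meshgrid(t, t, indexing="ij")
    data = c2.ravel()
    t1 = np.tile(t1_grid.ravel(), n_phi)
    t2 = np.tile(t2_grid.ravel(), n_phi)
    phi = np.repeat(phi_angles, n * n)

    # Step 2 (simplified): drop the diagonal, then index each point's angle.
    keep = t1 != t2
    data, t1, t2, phi = data[keep], t1[keep], t2[keep], phi[keep]
    phi_unique, phi_indices = np.unique(phi, return_inverse=True)

    # Step 3: per-point indices into the sorted time grid.
    i1 = np.searchsorted(t, t1)
    i2 = np.searchsorted(t, t2)
    return data, phi_unique, phi_indices, i1, i2
```

The round trip `c2[phi_indices, i1, i2] == data` is a cheap consistency check. Every pooled value should map back to the cell it came from.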
- heterodyne.optimization.cmc.core.fit_mcmc_jax(data, t1, t2, phi, q, L, analysis_mode, method='mcmc', cmc_config=None, initial_values=None, parameter_space=None, dt=None, output_dir=None, progress_bar=True, run_id=None, nlsq_result=None, **kwargs)
Homodyne-parity entry point for heterodyne CMC.
Mirrors homodyne.optimization.cmc.fit_mcmc_jax’s pooled-array call signature and routes to heterodyne’s native fit_cmc_jax / fit_cmc_sharded. This adapter exists so that cross-package CLI / driver code following homodyne’s pooled-data convention can call heterodyne’s CMC pipeline without reshaping by hand.
Heterodyne’s native API is fit_cmc_jax(model, c2_data, phi_angle, config, ...); new heterodyne code should prefer it directly.
- Parameters:
  - data, t1, t2, phi (ndarray) – Pooled C2 values and time/angle coordinates, all of shape (n_total,). (t1, t2) must be a flattened regular meshgrid for the reverse reshape to succeed.
  - q (float) – Wavevector magnitude (Å⁻¹).
  - L (float) – Stator-rotor gap. Accepted for homodyne parity; heterodyne’s physics model uses absolute time scaling, not L-normalised dimensionless time.
  - analysis_mode (str) – Accepted for homodyne parity; heterodyne always uses its two-component model regardless of this value.
  - method (str), output_dir (Any | None), progress_bar (bool), run_id (str | None) – Accepted for homodyne parity; consumed by heterodyne’s native pipeline where applicable.
  - cmc_config (dict[str, Any] | CMCConfig | None) – CMC configuration. Dicts are converted via CMCConfig.from_dict.
  - initial_values (dict[str, float] | None) – Initial parameter values applied to the constructed model.
  - parameter_space (Any | None) – Pre-built ParameterSpace. When None, a default one is built from DEFAULT_REGISTRY.
  - dt (float | None) – Time step. When None, inferred from np.diff(np.unique(t1 ∪ t2)).
  - nlsq_result (NLSQResult | dict | None) – Optional NLSQ warm-start. Dicts are ignored with a warning, since heterodyne’s native warm-start requires an NLSQResult instance.
- Returns: Result of the CMC fit.
- Return type: CMCResult
- Raises:
  - ValueError – If input shapes mismatch or the (t1, t2) pooled grid is not recoverable.
  - NotImplementedError – If pooled data contains multiple distinct phi angles; call this adapter once per angle, or use heterodyne.cli.optimization_runner.run_cmc for orchestrated multi-angle CMC.
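The reverse reshape that fit_mcmc_jax depends on can be sketched as follows. The sketch assumes t1 and t2 come from a full, regular meshgrid with every (t1, t2) pair present exactly once. It infers dt as the parameter description states, from the spacing of the unique times. `unpool_to_c2` is a hypothetical helper, not the adapter's internal code.

```python
import numpy as np


def unpool_to_c2(data, t1, t2, phi):
    """Hypothetical helper: rebuild (t, dt, c2) from pooled single-angle arrays."""
    data, t1, t2, phi = (np.asarray(a, dtype=float) for a in (data, t1, t2, phi))
    if data.ndim != 1 or not (data.shape == t1.shape == t2.shape == phi.shape):
        raise ValueError("data, t1, t2 and phi must all have shape (n_total,)")
    if np.unique(phi).size > 1:
        raise NotImplementedError("multiple phi angles: call once per angle")

    t = np.unique(np.concatenate([t1, t2]))   # sorted union of both time axes
    steps = np.diff(t)
    if t.size < 2 or not np.allclose(steps, steps[0]):
        raise ValueError("(t1, t2) is not a regular grid")
    dt = float(steps[0])

    n = t.size
    i1 = np.searchsorted(t, t1)
    i2 = np.searchsorted(t, t2)
    flat = i1 * n + i2                         # unique id per grid cell
    if data.size != n * n or np.unique(flat).size != n * n:
        raise ValueError("pooled grid is incomplete or has duplicate (t1, t2) pairs")
    c2 = np.empty((n, n))
    c2[i1, i2] = data
    return t, dt, c2
```

Indexing by `searchsorted` rather than reshaping directly makes the sketch independent of the order in which points were pooled.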
- heterodyne.optimization.cmc.core.run_cmc_analysis(model, c2_data, config=None, **kwargs)
Convenience wrapper around fit_cmc_jax() (homodyne parity).
Accepts the same arguments as fit_cmc_jax() and delegates directly.
- Return type: CMCResult
Configuration
Configuration for CMC (Consensus Monte Carlo) analysis.
This module defines CMCConfig, a comprehensive dataclass covering all aspects of the heterodyne CMC pipeline: sharding strategy, backend selection, NUTS sampling parameters, convergence validation thresholds, reparameterization, prior tempering, shard combination, and run identification.
The heterodyne model has 14 free parameters (vs. 7 in homodyne). All auto-scaling formulas account for this increased dimensionality.
- heterodyne.optimization.cmc.config.DENSE_MASS_WARMUP_FLOOR: int = 1500
Rule 12: NUTS warmup floor when dense_mass=True. Dense mass-matrix adaptation needs ≥ 100 steps per dimension; for the 14-parameter heterodyne model, 1500 warmup steps is the minimum that yields healthy R-hat in regression tests.
- heterodyne.optimization.cmc.config.effective_warmup_floor(requested, *, dense_mass, fast_warmup=False)
Return the safe warmup count, applying the Rule 12 floor when needed.
When dense_mass=True, NUTS dense mass-matrix adaptation requires at least DENSE_MASS_WARMUP_FLOOR steps; otherwise the requested value is returned unchanged. fast_warmup=True opts out of the floor and emits a one-shot warning; it is intended for CI fast-mode only.
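The documented behaviour fits in a few lines. This is a re-implementation sketch, not the library function. The name `effective_warmup_floor_sketch` and the module-level one-shot flag are assumptions. The sketch also assumes the fast_warmup warning is only relevant when a floor would otherwise apply.

```python
import warnings

DENSE_MASS_WARMUP_FLOOR = 1500  # value documented above
_fast_warmup_warned = False     # hypothetical one-shot flag


def effective_warmup_floor_sketch(requested: int, *, dense_mass: bool,
                                  fast_warmup: bool = False) -> int:
    """Sketch of the documented Rule 12 floor (not the library code)."""
    global _fast_warmup_warned
    if not dense_mass:
        return requested                  # diagonal mass matrix: no floor
    if fast_warmup:
        if not _fast_warmup_warned:       # warn only once per process
            warnings.warn("fast_warmup=True skips the dense-mass warmup floor",
                          stacklevel=2)
            _fast_warmup_warned = True
        return requested
    return max(requested, DENSE_MASS_WARMUP_FLOOR)
```

Requests above the floor pass through unchanged. Only too-short dense-mass warmups are raised.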
- class heterodyne.optimization.cmc.config.CMCConfig
Bases: object
Comprehensive configuration for Consensus Monte Carlo (CMC) analysis.
CMC splits a large dataset into K shards, runs NUTS independently on each shard, then combines the resulting posteriors using a consensus algorithm. This dataclass controls every knob across the full pipeline.
Parameters are grouped into logical sections, matching the structure of the to_dict / from_dict serialization format:
  - enable: master on/off switch and dataset-size gate.
  - per_angle: how to handle the phi (angle) dimension.
  - sharding: shard count, strategy, and size bounds.
  - backend_config: worker backend and checkpoint settings.
  - per_shard_mcmc: NUTS hyper-parameters and adaptive scaling.
  - validation: convergence thresholds and abort conditions.
  - nlsq: NLSQ warm-start and prior-width configuration.
  - prior_tempering: raise each shard’s prior to the power 1/K (divide its log-prior by K) for shard consistency.
  - combination: posterior combination algorithm and success criteria.
  - timeout: per-shard and heartbeat time limits.
  - reparameterization: parameter transforms and bimodality guards.
  - run_id: optional identifier for checkpoint namespacing.
- enable
Master switch. "auto" enables CMC when n_points >= min_points_for_cmc. True / "always" forces CMC regardless of dataset size. False / "never" disables CMC entirely.
- min_points_for_cmc
Minimum number of data points required before CMC is activated under enable="auto". Below this threshold the pipeline falls back to full NUTS.
- per_angle_mode
Strategy for handling the angle (phi) dimension. "auto" selects automatically based on n_phi and constant_scaling_threshold.
- constant_scaling_threshold
Minimum number of phi angles required before switching from "constant" to "individual" mode when per_angle_mode="auto".
- sharding_strategy
How to partition data across shards: "stratified" preserves angle distributions, "random" shuffles globally, and "contiguous" uses contiguous memory blocks.
- num_shards
Number of shards K. "auto" derives K from dataset size, phi count, and min_points_per_shard / min_points_per_param.
- max_points_per_shard
Upper bound on shard size. "auto" disables the cap.
- min_points_per_shard
Lower bound on shard size; prevents degenerate under-determined shards.
- min_points_per_param
Minimum number of data points per model parameter within each shard (default 1500; the heterodyne model has 14 parameters).
- backend_name
Worker backend. "auto" selects based on available CPU devices and core count.
- enable_checkpoints
Persist intermediate shard results to checkpoint_dir.
- checkpoint_dir
Directory for shard checkpoint files.
- chain_method
Whether to run chains in "parallel" or "sequential" order within each shard worker.
- num_warmup
Number of NUTS warm-up (burn-in) steps per chain. The default of 1500 provides ~100 steps per parameter for the 14-parameter heterodyne model when dense_mass=True; lower values leave mass-matrix adaptation incomplete and produce high R-hat and divergence storms.
- num_samples
Number of posterior draws per chain after warm-up.
- num_chains
Number of independent MCMC chains per shard.
- target_accept_prob
Target acceptance probability for the dual-averaging NUTS step-size adaptation (must be in [0.5, 0.99]). The default of 0.90 keeps step sizes small enough to traverse the (D0, alpha) funnel without divergence cascades; reduce it only if you understand the geometry.
- max_tree_depth
Maximum binary tree depth for NUTS leapfrog integration.
- seed
Base random seed for deterministic reproducibility.
- dense_mass
Use a dense (full-covariance) mass matrix. Expensive but more accurate for highly correlated posteriors.
- init_strategy
NUTS initialisation strategy.
- adaptive_sampling
Scale num_warmup / num_samples down proportionally when the shard size is below _REFERENCE_SHARD_SIZE.
- min_warmup
Floor on adaptive warm-up count.
- min_samples
Floor on adaptive sample count.
- max_r_hat
Maximum acceptable Gelman-Rubin statistic; chains with R-hat > max_r_hat are flagged as not converged.
- min_ess
Minimum effective sample size per parameter.
- min_bfmi
Minimum Bayesian Fraction of Missing Information (energy diagnostic).
- max_divergence_rate
Maximum fraction of divergent transitions before a shard is rejected.
- require_nlsq_warmstart
Abort if an NLSQ warm-start was requested but unavailable.
- allow_degenerate_warmstart
When False (default), an early RuntimeError is raised before dispatching any shards if the warm-start is in a regime that is guaranteed to cause 100% shard failure (f0 < 0.10 or alpha_sample < -1.5, the het_bb97531f failure mode). Set to True to bypass the abort and let NUTS attempt the run anyway. This is useful when you want to see exactly how bad the posteriors are, or when you have increased num_warmup substantially (≥ 2000).
- max_parameter_cv
Maximum allowed coefficient of variation across chains for any parameter; guards against pathological multi-modal posteriors.
- heterogeneity_abort
Abort the entire CMC run if shards produce incompatible posteriors (detected via KL divergence or parameter-CV checks).
- use_nlsq_warmstart
Initialise each shard’s NUTS chains from the NLSQ MAP estimate.
- use_nlsq_informed_priors
Centre Gaussian priors on the NLSQ estimates, with widths set by the NLSQ uncertainties scaled by nlsq_prior_width_factor.
- nlsq_prior_width_factor
Scale factor applied to NLSQ parameter uncertainties when constructing informed priors.
- prior_tempering
Divide each shard’s log-prior by K (the number of shards) so that the product of the K shard priors recovers the full-data prior exactly and the combined posterior approximates the full-data posterior.
- combination_method
Algorithm used to combine shard posteriors.
- min_success_rate
Minimum fraction of shards that must converge; run fails below this.
- min_success_rate_warning
Fraction below which a warning is emitted even if the run succeeds.
- per_shard_timeout
Wall-clock seconds allowed per shard before it is cancelled.
- heartbeat_timeout
Seconds of silence from a worker before it is declared dead.
- use_reparam
Apply parameter reparameterisations (e.g. log-transforms) in NumPyro.
- reparameterization_d_total
Reparameterise d_total = d_fast + d_slow as an unconstrained sum.
- reparameterization_log_gamma
Reparameterise gamma on a log scale to enforce positivity.
- bimodal_min_weight
Minimum mixture weight for the minor mode in bimodal posteriors; below this the minor mode is discarded.
- bimodal_min_separation
Minimum normalised distance between modes to declare bimodality.
- run_id
Optional string identifier for this CMC run, used in checkpoint paths and log messages.
- min_points_for_cmc: int = 100000
- per_angle_mode: str = 'auto'
- constant_scaling_threshold: int = 3
- sharding_strategy: str = 'random'
- min_points_per_shard: int = 10000
- min_points_per_param: int = 1500
- backend_name: str = 'auto'
- enable_checkpoints: bool = True
- checkpoint_dir: str = './checkpoints/cmc'
- chain_method: str = 'parallel'
- num_warmup: int = 1500
- num_samples: int = 1500
- num_chains: int = 4
- target_accept_prob: float = 0.9
- max_tree_depth: int = 10
- seed: int = 42
- dense_mass: bool = True
- init_strategy: str = 'init_to_median'
- adaptive_sampling: bool = True
- min_warmup: int = 100
- min_samples: int = 200
- fast_warmup: bool = False
Rule 12 escape hatch: skip the dense-mass warmup floor (1500 steps). Intended for CI fast-mode and pytest fixtures only; NOT for production. When True, all warmup calculations bypass effective_warmup_floor().
- max_r_hat: float = 1.1
- min_ess: int = 400
- min_bfmi: float = 0.3
- max_divergence_rate: float = 0.1
- require_nlsq_warmstart: bool = False
- allow_degenerate_warmstart: bool = False
- max_parameter_cv: float = 1.0
- heterogeneity_abort: bool = True
- use_nlsq_warmstart: bool = True
- use_nlsq_informed_priors: bool = True
- nlsq_prior_width_factor: float = 2.0
- use_log_space_priors: bool = True
Codex S1: integrate build_log_space_priors into build_default_priors. When True (default), parameters flagged log_space=True in the parameter registry (currently D0_ref, D0_sample, v0) are sampled with LogNormal priors instead of TruncatedNormal, giving better mass-matrix conditioning for prefactors that span several orders of magnitude. When False, the registry’s log_space flag is ignored and all parameters use TruncatedNormal. The reparameterized path (use_reparam=True) is unaffected: it samples log_X_at_tref directly and does not reach this code path.
- prior_tempering: bool = True
- combination_method: str = 'robust_consensus_mc'
- min_success_rate: float = 0.9
- min_success_rate_warning: float = 0.8
- per_shard_timeout: int = 7200
- heartbeat_timeout: int = 600
- enable_jax_profiling: bool = False
- jax_profile_dir: str = './profiles/jax'
- use_reparam: bool = True
- reparameterization_d_total: bool = True
- reparameterization_log_gamma: bool = True
- bimodal_min_weight: float = 0.2
- bimodal_min_separation: float = 0.5
- validate()
Run comprehensive field validation and return a list of error strings.
- is_valid()
Return True if the configuration passes all validation checks.
Equivalent to len(self.validate()) == 0.
- Return type: bool
- should_enable_cmc(n_points, analysis_mode=None)
Decide whether to run CMC given the dataset size.
- Returns: True if CMC should run for this dataset.
- Return type: bool
- get_num_shards(n_points, n_phi, n_params=_N_PARAMS_HETERODYNE)
Compute the number of shards K for a given dataset.
When num_shards is an explicit integer it is returned directly (clamped to >= 1). When "auto", K is derived as follows:
  1. Start from max(n_phi, 2), i.e. at least as many shards as phi angles.
  2. Apply the min_points_per_shard lower bound on shard size: K <= n_points // min_points_per_shard.
  3. Apply the min_points_per_param constraint: K <= n_points // (n_params * min_points_per_param).
  4. Apply the max_points_per_shard upper bound on shard size when set: K >= ceil(n_points / max_points_per_shard).
  5. Clamp to [1, n_points].
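The five steps translate directly into code. This is a sketch that uses the CMCConfig defaults and assumes `_N_PARAMS_HETERODYNE` is 14, the heterodyne model's free-parameter count. `auto_num_shards` is a hypothetical stand-in for the method.

```python
from __future__ import annotations

import math


def auto_num_shards(n_points: int, n_phi: int, *, min_points_per_shard: int = 10_000,
                    min_points_per_param: int = 1_500, n_params: int = 14,
                    max_points_per_shard: int | None = None) -> int:
    """Sketch of the documented num_shards="auto" rule (not the library code)."""
    k = max(n_phi, 2)                                          # step 1
    k = min(k, n_points // min_points_per_shard)               # step 2
    k = min(k, n_points // (n_params * min_points_per_param))  # step 3
    if max_points_per_shard is not None:                       # step 4
        k = max(k, math.ceil(n_points / max_points_per_shard))
    return max(1, min(k, n_points))                            # step 5
```

With the defaults, step 3 caps K at n_points // 21 000, so a 15 000-point dataset collapses to one shard. Setting max_points_per_shard is the only rule here that can push K above max(n_phi, 2).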
- get_adaptive_sample_counts(shard_size, n_params=_N_PARAMS_HETERODYNE)
Scale warmup and sample counts for a given shard size.
When adaptive_sampling=False the configured num_warmup and num_samples are returned unchanged.
The scaling law is:
    scale   = clamp(shard_size / reference_size, 0, 1)
    warmup  = max(min_warmup, round(num_warmup * scale))
    samples = max(min_samples, round(num_samples * scale))
where reference_size = _REFERENCE_SHARD_SIZE (10 000 points) is the shard size at which the full configured counts are used. Larger shards are not scaled up beyond the configured maximum; the formula saturates at scale = 1.
A secondary check ensures that at least n_params samples are drawn (ESS cannot exceed num_samples * num_chains).
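The following is a direct transcription of the scaling law, using the CMCConfig defaults. It assumes `_N_PARAMS_HETERODYNE = 14`, and `adaptive_sample_counts` is a hypothetical name. The sketch does not apply the dense-mass warmup floor; how the real method combines the two is not stated here.

```python
REFERENCE_SHARD_SIZE = 10_000  # _REFERENCE_SHARD_SIZE per the text above


def adaptive_sample_counts(shard_size, *, num_warmup=1500, num_samples=1500,
                           min_warmup=100, min_samples=200, n_params=14,
                           adaptive_sampling=True):
    """Sketch of the documented scaling law (not the library code)."""
    if not adaptive_sampling:
        return num_warmup, num_samples
    scale = min(max(shard_size / REFERENCE_SHARD_SIZE, 0.0), 1.0)  # clamp to [0, 1]
    warmup = max(min_warmup, round(num_warmup * scale))
    samples = max(min_samples, round(num_samples * scale))
    samples = max(samples, n_params)  # secondary check: at least n_params draws
    return warmup, samples
```

A 5 000-point shard gets half the configured counts, (750, 750). A 500-point shard falls back to the floors, (100, 200).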
- get_effective_per_angle_mode(n_phi, nlsq_per_angle_mode=None, has_nlsq_warmstart=False)
Resolve the effective per-angle mode for a concrete dataset.
Mirrors homodyne/optimization/cmc/config.py::get_effective_per_angle_mode so that CMC/NLSQ parameterization stays in lock-step across both packages.
Resolution logic (priority order):
  1. If nlsq_per_angle_mode is provided, mirror it for CMC↔NLSQ parameterization parity, regardless of self.per_angle_mode. If both sides are "auto" AND has_nlsq_warmstart is True, promote to "constant_averaged" so that scaling is fixed (fewer sampled params, less heterogeneity across shards).
  2. Else if self.per_angle_mode != "auto", return it directly.
  3. Else (auto, no NLSQ): n_phi >= constant_scaling_threshold → "auto" (sampled averaged); otherwise → "individual".
- Parameters:
  - n_phi (int) – Number of distinct phi (azimuthal angle) bins in the dataset.
  - nlsq_per_angle_mode (str | None) – The per-angle mode resolved by the preceding NLSQ fit, if any. When provided, it overrides the configured mode for parity.
  - has_nlsq_warmstart (bool) – Whether a valid NLSQ warm-start is available for this run.
- Returns: Effective mode: "auto", "constant", "constant_averaged", or "individual".
- Return type: str
- to_dict()
Serialise the configuration to a nested dictionary.
The returned structure uses the same section names expected by
from_dict, making round-trips lossless.
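The resolution order of get_effective_per_angle_mode can be sketched as a free function. The sketch reads "both sides are auto" as nlsq_per_angle_mode == "auto" and per_angle_mode == "auto". `resolve_per_angle_mode` is a hypothetical name, not the method itself.

```python
def resolve_per_angle_mode(per_angle_mode, n_phi, constant_scaling_threshold=3,
                           nlsq_per_angle_mode=None, has_nlsq_warmstart=False):
    """Sketch of the documented priority order (not the library code)."""
    # 1. An NLSQ-resolved mode wins, for CMC <-> NLSQ parity.
    if nlsq_per_angle_mode is not None:
        if (nlsq_per_angle_mode == "auto" and per_angle_mode == "auto"
                and has_nlsq_warmstart):
            return "constant_averaged"  # fixed scaling, fewer sampled params
        return nlsq_per_angle_mode
    # 2. An explicit configured mode is returned as-is.
    if per_angle_mode != "auto":
        return per_angle_mode
    # 3. Auto without NLSQ: decide by the number of angles.
    return "auto" if n_phi >= constant_scaling_threshold else "individual"
```

With the default threshold of 3, five angles resolve to "auto" and two resolve to "individual". An "auto" NLSQ mode with a valid warm-start is promoted to "constant_averaged".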
- classmethod from_dict(config_dict)
Construct a CMCConfig from a (possibly nested) dictionary.
Recognised top-level keys and sections:
  - enable, min_points_for_cmc, run_id: top-level scalars.
  - prior_tempering: top-level scalar.
  - per_angle: maps to per_angle_mode, constant_scaling_threshold.
  - sharding: maps to the five sharding fields.
  - backend_config: maps to the four backend fields.
  - per_shard_mcmc: maps to the eleven sampling fields.
  - validation: maps to the seven validation-threshold fields.
  - nlsq: maps to the three NLSQ-prior fields.
  - combination: maps to the three combination fields.
  - timeout: maps to the two timeout fields.
  - reparameterization: maps to the five reparam fields.
Flat (non-nested) dictionaries are also accepted for backward compatibility: any key that matches a field name directly is used as-is.
Unrecognised top-level keys emit a warnings.warn so that configuration typos surface immediately.
- __init__(enable='auto', min_points_for_cmc=100000, per_angle_mode='auto', constant_scaling_threshold=3, sharding_strategy='random', num_shards='auto', max_points_per_shard='auto', min_points_per_shard=10000, min_points_per_param=1500, backend_name='auto', enable_checkpoints=True, checkpoint_dir='./checkpoints/cmc', chain_method='parallel', num_warmup=1500, num_samples=1500, num_chains=4, target_accept_prob=0.9, max_tree_depth=10, seed=42, dense_mass=True, init_strategy='init_to_median', adaptive_sampling=True, min_warmup=100, min_samples=200, fast_warmup=False, max_r_hat=1.1, min_ess=400, min_bfmi=0.3, max_divergence_rate=0.1, require_nlsq_warmstart=False, allow_degenerate_warmstart=False, max_parameter_cv=1.0, heterogeneity_abort=True, use_nlsq_warmstart=True, use_nlsq_informed_priors=True, nlsq_prior_width_factor=2.0, use_log_space_priors=True, prior_tempering=True, combination_method='robust_consensus_mc', min_success_rate=0.9, min_success_rate_warning=0.8, per_shard_timeout=7200, heartbeat_timeout=600, enable_jax_profiling=False, jax_profile_dir='./profiles/jax', use_reparam=True, reparameterization_d_total=True, reparameterization_log_gamma=True, bimodal_min_weight=0.2, bimodal_min_separation=0.5, run_id=None, _validation_errors=<factory>)
Note
Key attribute names (renamed from legacy):
target_accept_prob, max_r_hat, nlsq_prior_width_factor.
The from_dict() class method handles legacy key translation.
Results
Result container for CMC analysis.
- class heterodyne.optimization.cmc.results.ParameterStats
Bases: dict
Hybrid mapping/sequence for posterior summaries.
Supports dict-style access by name (ps["D0_ref"]) and integer-index access (ps[0]). Inherits dict, so existing .get() / in checks continue to work unchanged.
- __init__(ordered_names, values)
- property as_array: ndarray
Ordered values as a numpy array.
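The hybrid behaviour can be reproduced by a small dict subclass that maps integer keys to names before lookup. This is a sketch: `ParameterStatsSketch` is hypothetical, and treating bools as names rather than positions is an assumption. Because `dict.get` and `in` do not go through `__getitem__`, they keep their plain mapping semantics.

```python
import numpy as np


class ParameterStatsSketch(dict):
    """Hypothetical dict subclass that also accepts integer positions."""

    def __init__(self, ordered_names, values):
        names = list(ordered_names)
        values = list(values)
        if len(names) != len(values):
            raise ValueError("ordered_names and values must have the same length")
        super().__init__(zip(names, values))
        self._names = names

    def __getitem__(self, key):
        # Integers (but not bools) are positions; everything else is a name.
        if isinstance(key, (int, np.integer)) and not isinstance(key, bool):
            key = self._names[key]
        return super().__getitem__(key)

    @property
    def as_array(self):
        """Values in the original name order."""
        return np.array([dict.__getitem__(self, n) for n in self._names], dtype=float)
```

Negative positions work as for lists: `ps[-1]` returns the last parameter's value.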
- class heterodyne.optimization.cmc.results.CMCResult
Bases: object
Result of CMC (Consensus Monte Carlo) analysis.
Contains posterior samples, summaries, and convergence diagnostics.
- posterior_mean: ndarray
- posterior_std: ndarray
- convergence_passed: bool
- num_warmup: int = 0
- num_samples: int = 0
- num_chains: int = 0
- num_shards: int = 1
- divergences: int = 0
- convergence_status: str = 'not_converged'
- per_angle_mode: str = 'auto'
- property n_params: int
Number of parameters.
- property parameters: ndarray
Homodyne-parity alias for posterior_mean.
- property uncertainties: ndarray
Homodyne-parity alias for posterior_std.
- get_param_summary(name)
Get summary statistics for a parameter.
- get_samples(name)
Get posterior samples for a parameter.
- validate_convergence(r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3)
Validate convergence diagnostics.
- get_samples_array()
Return samples as a 3-D array of shape (num_chains, num_samples, n_params).
Parameters with missing or None samples are filled with zeros. Flat 1-D sample arrays of length num_chains * num_samples are reshaped to (num_chains, num_samples) automatically.
- Return type: ndarray
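The stacking rule can be sketched as follows. The sketch assumes samples is a dict of per-parameter arrays and that flat arrays are chain-major, with the first num_samples draws belonging to chain 0. `samples_to_array` is a hypothetical helper.

```python
import numpy as np


def samples_to_array(samples, names, num_chains, num_samples):
    """Hypothetical: stack per-parameter draws into (chains, samples, params)."""
    out = np.zeros((num_chains, num_samples, len(names)))   # missing -> zeros
    for j, name in enumerate(names):
        draws = samples.get(name)
        if draws is None:
            continue
        arr = np.asarray(draws, dtype=float)
        if arr.ndim == 1:
            # Assumed chain-major layout: first num_samples draws are chain 0.
            arr = arr.reshape(num_chains, num_samples)
        out[:, :, j] = arr
    return out
```

Zero-filling keeps the array shape fixed. Downstream code should still check which parameters were actually sampled, because a column of zeros is indistinguishable from genuine zero draws.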
- get_posterior_stats()
Return per-parameter posterior statistics.
Returns a dict keyed by parameter name. Each value contains mean, std, median, hdi_5%, hdi_95%, r_hat, ess_bulk, and ess_tail. Parameters absent from self.samples are omitted.
- classmethod from_mcmc_samples(mcmc_samples, stats, analysis_mode='static', n_warmup=500, min_ess=None)
Build a CMCResult from raw MCMC samples (homodyne parity).
Mirrors homodyne.optimization.cmc.results.CMCResult.from_mcmc_samples. Duck-typed: mcmc_samples must expose .samples (dict[str, ndarray]), .param_names (list[str]), .n_chains (int), and .n_samples (int); stats must expose .num_divergent (int) and may expose .wall_time / .warmup_time.
Diagnostics (R-hat, ESS) are not computed here, because they require per-chain reshaping and ArviZ. Callers that need diagnostics should run cmc_result_to_arviz() and overwrite .r_hat / .ess_* after construction.
- Parameters:
  - mcmc_samples (Any) – Object holding posterior draws; see the duck-typed surface above.
  - stats (Any) – Sampling statistics object; see the duck-typed surface above.
  - analysis_mode (str) – Stored on the result for plot / report consumers. Default "static" mirrors homodyne.
  - n_warmup (int) – Number of warmup draws (recorded on the result).
  - min_ess (float | None) – Accepted for homodyne parity; ignored here because this factory does not compute diagnostics.
- Returns: Populated result with parameter_names, posterior_mean, posterior_std, samples, and basic sampling/divergence metadata. credible_intervals is left empty; downstream consumers can populate it via get_posterior_stats.
- Return type: CMCResult
- __init__(parameter_names, posterior_mean, posterior_std, credible_intervals, convergence_passed, r_hat=None, ess_bulk=None, ess_tail=None, bfmi=None, samples=None, map_estimate=None, num_warmup=0, num_samples=0, num_chains=0, num_shards=1, divergences=0, wall_time_seconds=None, metadata=<factory>, convergence_status='not_converged', warmup_time=None, per_angle_mode='auto', chi_squared=None, quality_flag=None, mean_contrast=None, std_contrast=None, mean_offset=None, std_offset=None, inference_data=None)
- heterodyne.optimization.cmc.results.cmc_result_to_arviz(result)
Convert a CMCResult to an ArviZ InferenceData object.
Samples stored in result.samples are reshaped to (num_chains, num_draws) when result.num_chains > 1 so that ArviZ can compute per-chain diagnostics (R-hat, ESS). When the result carries flat 1-D arrays, the function treats the entire sequence as a single chain.
- Parameters:
  - result (CMCResult) – Completed CMC analysis result.
- Return type: arviz.InferenceData
- Returns: arviz.InferenceData with a posterior group populated from result.samples and, when available, sample_stats populated from result.bfmi.
- Raises:
  - ImportError – If ArviZ is not installed.
  - ValueError – If result.samples is None or empty.
- heterodyne.optimization.cmc.results.compare_cmc_nlsq(cmc_result, nlsq_result, consistency_sigma=2.0)
Compare CMC posterior means with NLSQ point estimates.
Parameters that appear in both results are compared. Parameters present in only one result are silently skipped.
- Parameters:
  - cmc_result (CMCResult) – Completed CMC result.
  - nlsq_result (Any) – Completed NLSQ result (NLSQResult instance).
  - consistency_sigma (float) – Number of posterior standard deviations within which the NLSQ estimate must fall to be flagged as consistent. Defaults to 2.0 (approximately a 95% credible interval).
- Returns: A dict with the keys:
  - "common_parameters": list of parameter names present in both.
  - "differences": dict mapping name to (cmc_mean - nlsq_value).
  - "relative_deviations": dict mapping name to abs(cmc_mean - nlsq_value) / cmc_std.
  - "consistent": dict mapping name to bool (True if within consistency_sigma posterior std of the CMC mean).
  - "n_consistent": int count of consistent parameters.
  - "n_inconsistent": int count of inconsistent parameters.
  - "consistency_sigma": the threshold used.
- Return type: dict
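The comparison can be sketched on plain dicts, returning the same keys as documented above. The sketch makes several assumptions:

- CMC means, CMC standard deviations, and NLSQ values are given as name-keyed dicts.
- "Within" means less than or equal to.
- A zero posterior std is treated as infinitely tight.

`compare_estimates` is a hypothetical helper.

```python
def compare_estimates(cmc_mean, cmc_std, nlsq_values, consistency_sigma=2.0):
    """Hypothetical dict-based version of the comparison documented above."""
    common = [name for name in cmc_mean if name in nlsq_values]  # skip one-sided
    differences = {n: cmc_mean[n] - nlsq_values[n] for n in common}
    # Deviation in units of posterior std; zero std counts as infinitely tight.
    relative = {
        n: abs(differences[n]) / cmc_std[n] if cmc_std[n] > 0 else float("inf")
        for n in common
    }
    consistent = {n: relative[n] <= consistency_sigma for n in common}
    n_ok = sum(consistent.values())
    return {
        "common_parameters": common,
        "differences": differences,
        "relative_deviations": relative,
        "consistent": consistent,
        "n_consistent": n_ok,
        "n_inconsistent": len(common) - n_ok,
        "consistency_sigma": consistency_sigma,
    }
```

A parameter whose NLSQ value sits 3 posterior standard deviations from the CMC mean is flagged inconsistent at the default threshold of 2.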
- heterodyne.optimization.cmc.results.merge_shard_cmc_results(shard_results, parameter_names=None)[source]
Combine multiple shard CMCResults into a single consensus result.
Uses inverse-variance weighting (precision weighting) to combine posterior means from independent shards, following the Consensus Monte Carlo methodology (Scott et al., 2016). Diagnostics (R-hat, ESS, BFMI) are set to their worst-case values across shards so that failures are never hidden by averaging.
- Parameters:
shard_results (list[CMCResult]) – Non-empty list of per-shard CMCResults. Each must have the same parameter_names (or the parameter_names override must be supplied).
parameter_names (list[str] | None) – Optional explicit parameter name list. When supplied, only these parameters are included in the merged result; they must be present in every shard.
- Return type:
CMCResult
- Returns:
A new CMCResult representing the consensus posterior.
- Raises:
ValueError – If shard_results is empty or parameter names are inconsistent across shards when no override is given.
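The precision weighting and worst-case diagnostics described for merge_shard_cmc_results can be illustrated for a single parameter. The sketch below works on per-shard moments rather than CMCResult objects. The helper name merge_gaussian_shards is hypothetical. The combined standard deviation uses the usual Gaussian consensus rule, and this section does not state whether the library also combines widths this way.

```python
import math


def merge_gaussian_shards(means, stds, rhats, ess, bfmis):
    """Hypothetical sketch: precision-weighted consensus of per-shard moments."""
    precisions = [1.0 / s ** 2 for s in stds]
    total_precision = sum(precisions)
    # Inverse-variance weighted mean: tighter shards pull harder.
    mean = sum(p * m for p, m in zip(precisions, means)) / total_precision
    std = math.sqrt(1.0 / total_precision)
    # Worst-case diagnostics, so one bad shard is never averaged away.
    diagnostics = {"r_hat": max(rhats), "ess_bulk": min(ess), "bfmi": min(bfmis)}
    return mean, std, diagnostics
```

With shard means 0 and 3 and standard deviations 1 and 2, the consensus mean is 0.6. The narrower shard dominates because its precision is four times larger.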
- heterodyne.optimization.cmc.results.cmc_result_summary_table(result, ci_level='95', width=80)[source]
Format a CMCResult as a human-readable parameter summary table.
The table includes columns for parameter name, posterior mean, posterior standard deviation, credible interval bounds, R-hat, and bulk ESS. Missing diagnostics are shown as N/A.
- Parameters:
- Return type:
- Returns:
Multi-line string containing the formatted table.
- Raises:
ValueError – If ci_level is not "95" or "89".
NumPyro Model
NumPyro model definition for heterodyne Bayesian inference.
- heterodyne.optimization.cmc.model.xpcs_model_heterodyne_scaled(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, num_shards=1)[source]
Joint multi-phi heterodyne CMC model (homodyne parity).
Mirrors homodyne.optimization.cmc.model.xpcs_model_scaled: a single NUTS pass over pooled multi-phi data, with the 14 physics parameters shared across angles and contrast/offset sampled per angle. The likelihood site evaluates the Normal log-prob at every pooled point in a single numpyro.sample call, gathering per-point quantities by phi index.
- Parameters:
data (Array) – Pooled C2 values, shape (n_total,) after diagonal filtering.
t (Array) – Unique time grid, shape (N,). Used by compute_c2_heterodyne.
q (float) – Physics scalar: wavevector magnitude.
dt (float) – Physics scalar: lag-time step.
phi_unique (Array) – Sorted unique phi angles, shape (n_phi,).
phi_indices (Array) – Per-point index into phi_unique, shape (n_total,).
i1_indices (Array) – Per-point index into t for the first time coordinate, shape (n_total,). Pre-computed via np.searchsorted(t, t1).
i2_indices (Array) – Per-point index into t for the second time coordinate, shape (n_total,). Pre-computed via np.searchsorted(t, t2).
noise_scale (float) – Data-driven sigma prior centre (homodyne-parity HalfNormal scale = noise_scale * 1.5 * sqrt(num_shards)).
space (ParameterSpace) – Parameter space holding priors and initial values for the 14 physics parameters plus 2 scaling parameters.
num_shards (int) – Shard count for CMC sigma-prior tempering (Scott et al. 2016). Default 1 (no tempering).
- Return type:
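The index bookkeeping behind the pooled likelihood can be sketched without JAX. bisect_left gives the same result as np.searchsorted with its default side="left". The gather picks one theory value per pooled point. The sketch assumes, purely for illustration, a dense (n_phi, N, N) theory table and time values that lie exactly on the grid. Both helper names are hypothetical.

```python
from bisect import bisect_left


def build_pooled_indices(t, t1, t2):
    """Equivalent of np.searchsorted(t, t1) / np.searchsorted(t, t2) (side='left')."""
    i1 = [bisect_left(t, x) for x in t1]
    i2 = [bisect_left(t, x) for x in t2]
    return i1, i2


def gather_theory(c2_theory_per_phi, phi_indices, i1, i2):
    """Pick theory[phi][i1][i2] for every pooled data point (hypothetical helper)."""
    return [c2_theory_per_phi[p][a][b] for p, a, b in zip(phi_indices, i1, i2)]
```

In JAX the same gather would be a single fancy-indexing expression over the three index arrays. That is what allows all pooled points to share one numpyro.sample site.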
- heterodyne.optimization.cmc.model.xpcs_model_heterodyne_constant(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, fixed_contrast, fixed_offset, num_shards=1)[source]
Joint multi-phi CMC model with FIXED per-angle scaling.
Mirrors homodyne.optimization.cmc.model.xpcs_model_constant. Per-angle contrast and offset are passed in as arrays (length n_phi, typically derived from quantile estimation on the raw data) and are NOT sampled; only the 14 physics parameters and sigma are sampled.
- Return type:
- heterodyne.optimization.cmc.model.xpcs_model_heterodyne_averaged(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, num_shards=1)[source]
Joint multi-phi CMC model with SAMPLED averaged (single) scaling.
Mirrors homodyne.optimization.cmc.model.xpcs_model_averaged. A single contrast and a single offset are sampled and broadcast across all n_phi angles (cf. heterodyne per_angle_mode="auto" when the auto-resolver promotes to averaged scaling).
- Return type:
- heterodyne.optimization.cmc.model.xpcs_model_heterodyne_constant_averaged(data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, fixed_contrast, fixed_offset, num_shards=1)[source]
Joint multi-phi CMC model with FIXED averaged (single) scaling.
Mirrors homodyne.optimization.cmc.model.xpcs_model_constant_averaged. A single contrast and offset (typically the mean of the NLSQ per-angle estimates) are broadcast across all n_phi angles. No scaling parameters are sampled; only the 14 physics parameters and sigma are.
- Return type:
- heterodyne.optimization.cmc.model.get_heterodyne_pooled_model_for_mode(per_angle_mode, *, data, t, q, dt, phi_unique, phi_indices, i1_indices, i2_indices, noise_scale, space, fixed_contrast=None, fixed_offset=None, num_shards=1)[source]
Dispatch to the joint multi-phi model variant for per_angle_mode.
Mirrors homodyne.optimization.cmc.model.get_xpcs_model at the pooled-data layer. Returns a zero-arg callable suitable for passing to NUTS.
Modes:
"individual" / "scaled" → xpcs_model_heterodyne_scaled() (per-angle sampled contrast/offset).
"constant" → xpcs_model_heterodyne_constant(). Requires fixed_contrast and fixed_offset as length-n_phi arrays.
"auto" / "averaged" → xpcs_model_heterodyne_averaged() (single sampled averaged contrast/offset).
"constant_averaged" → xpcs_model_heterodyne_constant_averaged(). Requires scalar fixed_contrast and fixed_offset.
- heterodyne.optimization.cmc.model.get_heterodyne_model(t, q, dt, phi_angle, c2_data, noise_scale, space, contrast=1.0, offset=1.0, shard_grid=None, priors_override=None, num_shards=1)[source]
Create NumPyro model for heterodyne correlation fitting.
Sigma is sampled as a posterior variable via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)), matching the homodyne parity convention so the posterior captures noise uncertainty.
- Parameters:
t (Array) – Time array.
q (float) – Wavevector.
dt (float) – Time step.
phi_angle (float) – Detector phi angle.
c2_data (Array) – Observed correlation data: shape (N, N) for the meshgrid path, or (n_pairs,) for the element-wise path.
noise_scale (float) – Data-driven prior center for the measurement-uncertainty sigma posterior. Typically the mean / RMS of an external estimate from estimate_sigma().
space (ParameterSpace) – Parameter space with priors.
contrast (float) – Speckle contrast (beta), default 1.0.
offset (float) – Baseline offset, default 1.0.
shard_grid (ShardGrid | None) – Optional pre-computed ShardGrid. When provided, uses the memory-efficient element-wise path (no N×N allocation). c2_data must then be flattened to match the shard grid’s paired indices.
priors_override (dict | None) – Optional dictionary mapping parameter names to NumPyro distributions. When provided, overrides the default space.priors[name] for any matching parameter name. Used by fit_cmc_sharded to inject tempered priors.
num_shards (int) – Number of CMC shards for sigma prior tempering. Widens the HalfNormal scale by sqrt(num_shards) so that the product of the K per-shard priors is proportional to the unsharded prior. Defaults to 1 (no tempering).
- Returns:
NumPyro model function
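The sqrt(num_shards) widening can be checked numerically. For a HalfNormal, multiplying K copies of the density with scale s·sqrt(K) gives the same dependence on sigma as a single density with scale s, so log-density differences match exactly. The sketch below uses plain math rather than numpyro.distributions.HalfNormal. The helper names are illustrative only.

```python
import math


def halfnormal_logpdf(x, scale):
    """Log density of HalfNormal(scale) for x >= 0."""
    return 0.5 * math.log(2.0 / math.pi) - math.log(scale) - x ** 2 / (2.0 * scale ** 2)


def sigma_prior_scale(noise_scale, num_shards=1):
    """Homodyne-parity rule quoted above: noise_scale * 1.5 * sqrt(num_shards)."""
    return noise_scale * 1.5 * math.sqrt(num_shards)
```

With K = 4 and noise_scale = 0.02, the per-shard scale doubles from 0.03 to 0.06. Summing four per-shard log-density differences between two sigma values reproduces the unsharded difference.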
- heterodyne.optimization.cmc.model.get_heterodyne_model_reparam(t, q, dt, phi_angle, c2_data, noise_scale, space, nlsq_params=None, reparam_config=None, scalings=None, contrast=1.0, offset=1.0, shard_grid=None, num_shards=1)[source]
Create NumPyro model with reparameterization for better sampling.
When an NLSQ result and a reparameterization config are provided, the model uses:
1. Reference-time reparameterization for power-law pairs.
2. Smooth bounded transforms (tanh) instead of jnp.clip().
3. NLSQ-informed priors with delta-method uncertainty propagation.
Falls back to the original clip-based behavior when the new infrastructure is not provided (backward compatibility).
Sigma is sampled internally via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)) to match homodyne CMC parity.
- Parameters:
t (Array) – Time array.
q (float) – Wavevector.
dt (float) – Time step.
phi_angle (float) – Detector phi angle.
c2_data (Array) – Observed correlation data.
noise_scale (float) – Data-driven prior center for sampled sigma.
space (ParameterSpace) – Parameter space.
nlsq_params (Array | None) – Optional NLSQ fitted values for centering (legacy path).
reparam_config (ReparamConfig | None) – Reparameterization config (enables the new path).
scalings (dict[str, ParameterScaling] | None) – Pre-computed ParameterScaling per reparam-space parameter.
num_shards (int) – CMC shard count for sigma prior tempering. Default 1.
- Returns:
NumPyro model function
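Two of the ideas listed for the reparameterized path can be sketched in isolation. The first is a tanh map from an unconstrained value into (lo, hi). Unlike clipping, it keeps a non-zero gradient everywhere. The second is a reference-time reparameterization of a power law X(t) = prefactor · t^exponent, which samples log X(t_ref) instead of the prefactor. Both are assumptions about the general technique. This section does not give the library's exact transform or its choice of t_ref, and all helper names here are hypothetical.

```python
import math


def bounded_from_unconstrained(z, lo, hi):
    """Smooth map R -> (lo, hi) via tanh; unlike clip, its gradient never vanishes."""
    return lo + (hi - lo) * 0.5 * (math.tanh(z) + 1.0)


def unconstrained_from_bounded(x, lo, hi):
    """Inverse of the tanh map, valid for lo < x < hi."""
    return math.atanh(2.0 * (x - lo) / (hi - lo) - 1.0)


def log_value_at_tref(prefactor, exponent, t_ref):
    """log X(t_ref) for X(t) = prefactor * t**exponent."""
    return math.log(prefactor) + exponent * math.log(t_ref)


def prefactor_from_tref(log_x_tref, exponent, t_ref):
    """Recover the prefactor from the sampled log X(t_ref) and the exponent."""
    return math.exp(log_x_tref - exponent * math.log(t_ref))
```

Choosing t_ref inside the observed time window typically weakens the strong prefactor/exponent correlation. That correlation is what makes NUTS struggle with the raw power-law pair.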
- heterodyne.optimization.cmc.model.get_heterodyne_model_constant(t, q, dt, phi_angle, c2_data, noise_scale, space, fixed_contrast, fixed_offset, shard_grid=None, num_shards=1)[source]
Create NumPyro model with FIXED (pre-computed) per-angle scaling.
Contrast and offset are not sampled; they are provided as fixed arrays from a preceding NLSQ or preprocessing step. Suitable for per_angle_mode="constant", where each angle has its own fixed scaling but the physical parameters are shared.
Sigma is sampled internally via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)) for homodyne CMC parity.
- Parameters:
t (Array) – Time array, shape (n_t,).
q (float) – Wavevector magnitude (Å⁻¹).
dt (float) – Lag-time step (s).
phi_angle (float) – Detector phi angle for this shard (degrees).
c2_data (Array) – Observed correlation data, shape (n_t,) or (n_phi, n_t).
noise_scale (float) – Data-driven prior center for sampled sigma.
space (ParameterSpace) – Parameter space carrying priors and fixed values.
fixed_contrast (Array) – Speckle contrast per angle, shape (n_phi,) or scalar.
fixed_offset (Array) – Baseline offset per angle, shape (n_phi,) or scalar.
num_shards (int) – CMC shard count for sigma prior tempering. Default 1.
- Returns:
NumPyro model callable (no required arguments).
- heterodyne.optimization.cmc.model.get_heterodyne_model_constant_averaged(t, q, dt, phi_angle, c2_data, noise_scale, space, mean_contrast, mean_offset, shard_grid=None, num_shards=1)[source]
Create NumPyro model with a single averaged scaling broadcast to all angles.
Both mean_contrast and mean_offset are scalars computed from the average over all phi angles. They are treated as fixed (not sampled) and broadcast uniformly. Suitable for per_angle_mode="constant_averaged".
Sigma is sampled internally via HalfNormal(noise_scale * 1.5 * sqrt(num_shards)) for homodyne CMC parity.
- Parameters:
t (Array) – Time array, shape (n_t,).
q (float) – Wavevector magnitude (Å⁻¹).
dt (float) – Lag-time step (s).
phi_angle (float) – Detector phi angle for this shard (degrees).
c2_data (Array) – Observed correlation data.
noise_scale (float) – Data-driven prior center for sampled sigma.
space (ParameterSpace) – Parameter space carrying priors and fixed values.
mean_contrast (float) – Scalar speckle contrast averaged over all phi angles.
mean_offset (float) – Scalar baseline offset averaged over all phi angles.
num_shards (int) – CMC shard count for sigma prior tempering. Default 1.
- Returns:
NumPyro model callable (no required arguments).
- heterodyne.optimization.cmc.model.get_heterodyne_model_individual(t, q, dt, phi_angles, c2_data, noise_scale, space, contrast_prior_loc=0.5, contrast_prior_scale=0.25, offset_prior_loc=1.0, offset_prior_scale=0.25, shard_grids=None, num_shards=1)[source]
Create NumPyro model with per-angle sampled contrast and offset.
The most general per-angle model: independently samples contrast_i and offset_i for each phi angle using weakly informative Gaussian priors. Suitable for per_angle_mode="individual".
Physical parameters are shared across all angles; the per-angle scaling lives in a numpyro.plate over the angle dimension.
- Parameters:
t (Array) – Time array, shape (n_t,).
q (float) – Wavevector magnitude (Å⁻¹).
dt (float) – Lag-time step (s).
phi_angles (Array) – Detector phi angles, shape (n_phi,).
c2_data (Array) – Observed correlation data, shape (n_phi, n_t).
sigma – Measurement uncertainty: scalar or shape (n_phi, n_t).
space (ParameterSpace) – Parameter space carrying priors and fixed values.
contrast_prior_loc (Array | float) – Prior centre(s) for contrast. Scalar or (n_phi,) array. Default 0.5.
contrast_prior_scale (float) – Prior width for contrast. Default 0.25.
offset_prior_loc (Array | float) – Prior centre(s) for offset. Scalar or (n_phi,) array. Default 1.0.
offset_prior_scale (float) – Prior width for offset. Default 0.25.
shard_grids (list[ShardGrid] | None) – Optional list of pre-computed ShardGrids, one per phi angle. When provided, uses the memory-efficient element-wise path (no N×N allocation per angle). c2_data[ai] and sigma[ai] must then be flattened to match each shard grid’s paired indices. Without this, the model builds n_phi N×N matrices per NUTS step, which can cause OOM for large datasets.
- Returns:
NumPyro model callable (no required arguments).
- heterodyne.optimization.cmc.model.get_model_for_mode(per_angle_mode, t, q, dt, phi_angle, c2_data, noise_scale, space, nlsq_result=None, reparam_config=None, num_shards=1, **kwargs)[source]
Select and build the appropriate NumPyro model based on per-angle mode.
Factory that maps per_angle_mode strings to concrete model constructors. Extra keyword arguments are forwarded to the selected constructor, allowing callers to pass mode-specific parameters (e.g. fixed_contrast, mean_contrast, phi_angles) without branching at the call site.
Mapping
"auto"
Delegates to get_heterodyne_model() (sampled contrast/offset from the parameter space) or get_heterodyne_model_reparam() when reparam_config is supplied.
"constant"
Delegates to get_heterodyne_model_constant(). Requires fixed_contrast and fixed_offset in kwargs.
"constant_averaged"
Delegates to get_heterodyne_model_constant_averaged(). Requires mean_contrast and mean_offset in kwargs.
"individual"
Delegates to get_heterodyne_model_individual(). Requires phi_angles and c2_data shaped (n_phi, n_t) in kwargs.
- Parameters:
per_angle_mode – One of "auto", "constant", "constant_averaged", "individual".
t – Time array.
q – Wavevector magnitude (Å⁻¹).
dt – Lag-time step (s).
phi_angle – Scalar phi angle (used by non-individual modes).
c2_data – Observed correlation data.
noise_scale – Data-driven prior centre for the sampled sigma site.
space (ParameterSpace) – Parameter space.
nlsq_result (NLSQResult | None) – Optional NLSQ result for warm-starting (used by "auto" mode when reparam_config is supplied).
reparam_config (ReparamConfig | None) – Optional reparameterization config. When provided alongside "auto" mode, activates the reparam model path.
num_shards – CMC shard count for sigma prior tempering. Default 1.
**kwargs – Mode-specific keyword arguments forwarded verbatim.
- Return type:
- Returns:
NumPyro model callable (no required arguments).
- Raises:
ValueError – If per_angle_mode is not a recognised string.
- heterodyne.optimization.cmc.model.estimate_sigma(c2_data, method='diagonal', nlsq_result=None, n_bootstrap=200, bootstrap_seed=0)[source]
Estimate measurement uncertainty from data.
Supported methods:
"diagonal" – Uses the standard deviation of the diagonal of c2_data relative to its mean, floored at 1 % of the data’s overall scale. Fast and requires no additional information.
"constant" – Returns the overall standard deviation of c2_data as a scalar.
"local" – Computes a spatially smoothed local variance via scipy.ndimage.uniform_filter. Requires SciPy.
"residual" – Estimates sigma from the RMS of NLSQ residuals. Requires nlsq_result with a non-None residuals field. Falls back to "diagonal" if residuals are unavailable.
"bootstrap" – Draws n_bootstrap bootstrap replicates of the diagonal and returns the standard deviation of the per-replicate means as the noise estimate. Useful when the diagonal has enough points to bootstrap.
- Parameters:
c2_data (Array) – Correlation data, shape (n_t,) or (n_phi, n_t).
method (str) – Estimation method: one of "diagonal", "constant", "local", "residual", "bootstrap".
nlsq_result (NLSQResult | None) – NLSQ result object. Required (and used) only for method="residual".
n_bootstrap (int) – Number of bootstrap replicates for method="bootstrap". Default 200.
bootstrap_seed (int) – JAX PRNG seed for method="bootstrap". Default 0.
- Return type:
- Returns:
Estimated sigma: same shape as c2_data for "local", scalar or (n_t,) array for all other methods.
- Raises:
ValueError – If method is not a recognised string.
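The "bootstrap" method can be sketched in pure Python. The sketch resamples the diagonal with replacement, averages each replicate, and reports the spread of those averages. It uses random.Random as a stand-in for the JAX PRNG the library uses, so seeds do not reproduce the library's numbers. The name bootstrap_sigma is hypothetical.

```python
import random
import statistics


def bootstrap_sigma(diagonal, n_bootstrap=200, seed=0):
    """Standard deviation of bootstrap-replicate means of the C2 diagonal."""
    rng = random.Random(seed)  # stand-in for the JAX PRNG used by the library
    n = len(diagonal)
    replicate_means = [
        statistics.fmean(rng.choices(diagonal, k=n)) for _ in range(n_bootstrap)
    ]
    return statistics.stdev(replicate_means)
```

Note that the spread of replicate means scales roughly like the per-point spread divided by sqrt(n). A perfectly flat diagonal therefore returns 0.0.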
- heterodyne.optimization.cmc.model.validate_model_output(c2_theory, params)[source]
Validate that theoretical C2 values are physically reasonable.
Checks for NaN/inf values and enforces the heterodyne C2 range constraint [-1.0, 10.0]. Heterodyne C2 can go negative due to the velocity phase term, unlike homodyne, where C2 >= 0.
- heterodyne.optimization.cmc.model.get_model_param_count(n_phi, per_angle_mode='individual')[source]
Return total number of sampled parameters for the model.
Accounts for per-angle mode semantics when counting contrast/offset parameters that are sampled in addition to the shared physics parameters.
Per-angle mode contributions:
"constant": 0 per-angle params (fixed contrast/offset).
"constant_averaged": 0 per-angle params (fixed averaged contrast/offset).
"auto": physics params only (contrast/offset live in the parameter space and are already counted).
"individual": 2 * n_phi per-angle params (contrast_z + offset_z per angle).
- Parameters:
- Return type:
- Returns:
Total number of sampled parameters (int).
- Raises:
ValueError – If per_angle_mode is not recognised.
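The per-mode contributions listed above amount to a small lookup table. In this sketch the count of shared, varying physics parameters is passed in explicitly (n_physics is a hypothetical argument), because this section does not state the base count the library starts from. For "auto", any contrast/offset entries are assumed to be already part of n_physics.

```python
def count_sampled_params(n_physics, n_phi, per_angle_mode="individual"):
    """Hypothetical sketch of the per-angle bookkeeping described above."""
    per_angle = {
        "constant": 0,           # fixed contrast/offset
        "constant_averaged": 0,  # fixed averaged contrast/offset
        "auto": 0,               # scaling already counted inside n_physics
        "individual": 2 * n_phi, # one contrast_z and one offset_z per angle
    }
    if per_angle_mode not in per_angle:
        raise ValueError(f"unknown per_angle_mode: {per_angle_mode!r}")
    return n_physics + per_angle[per_angle_mode]
```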
Priors
NLSQ-informed prior construction for heterodyne CMC analysis.
Builds NumPyro distribution dictionaries from NLSQ warm-start results
or from the parameter registry defaults, including log-space priors
for parameters flagged with log_space=True.
- heterodyne.optimization.cmc.priors.build_nlsq_informed_priors(nlsq_result, param_space, width_factor=2.0)[source]
Build priors centered on NLSQ point estimates.
For each varying parameter, constructs a truncated Normal prior centered on the NLSQ best-fit value with width equal to the NLSQ uncertainty multiplied by width_factor. When NLSQ uncertainty is unavailable, falls back to the registry prior_std or a fraction of the parameter range.
- Parameters:
nlsq_result (NLSQResult) – Converged NLSQ result with parameter values and (optionally) uncertainties.
param_space (ParameterSpace) – Parameter space defining which parameters vary and their physical bounds.
width_factor (float) – Multiplier on NLSQ uncertainty to set prior width. Larger values give more diffuse priors. The default 2.0 makes one prior standard deviation equal to two NLSQ standard errors, so the ±1σ prior band spans roughly 4 NLSQ σ around the estimate.
- Return type:
- Returns:
Dictionary mapping parameter names to NumPyro distributions. Only includes parameters in
param_space.varying_names.
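The width-selection logic reads as a three-level fallback. The sketch below returns the (loc, scale, low, high) tuple that would feed a truncated Normal. In NumPyro that would be numpyro.distributions.TruncatedNormal(loc, scale, low=..., high=...). It stops short of building the distribution so that it runs without NumPyro. The range_fraction default is a hypothetical stand-in, because the fraction the library uses is not given here.

```python
import math


def nlsq_prior_spec(value, uncertainty, lo, hi, width_factor=2.0,
                    registry_prior_std=None, range_fraction=0.1):
    """Return (loc, scale, low, high) for a TruncatedNormal prior (sketch)."""
    if uncertainty is not None and math.isfinite(uncertainty) and uncertainty > 0:
        scale = width_factor * uncertainty   # NLSQ uncertainty available
    elif registry_prior_std is not None:
        scale = registry_prior_std           # registry prior_std fallback
    else:
        scale = range_fraction * (hi - lo)   # hypothetical fraction of the range
    return value, scale, lo, hi
```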
- heterodyne.optimization.cmc.priors.build_default_priors(param_space, registry=None, use_log_space_priors=True)[source]
Build default priors from the parameter registry.
Uses prior_mean and prior_std from each parameter’s ParameterInfo. All bounded parameters use TruncatedNormal so that temper_priors can scale them by sqrt(K) for Consensus Monte Carlo sharding.
Codex S1: when use_log_space_priors=True (default), parameters flagged log_space=True in the registry (D0_ref, D0_sample, v0) are overridden with LogNormal priors via build_log_space_priors(). LogNormal mass-matrix conditioning is much better for prefactors that span several orders of magnitude. Pass use_log_space_priors=False to fall back to TruncatedNormal uniformly.
The reparameterized path (CMCConfig.use_reparam=True) is unaffected: it samples log_X_at_tref directly and never calls this function for those parameters.
- Parameters:
param_space (ParameterSpace) – Parameter space defining which parameters vary and their physical bounds.
registry (ParameterRegistry | None) – Parameter registry to read metadata from. Defaults to DEFAULT_REGISTRY.
use_log_space_priors (bool) – Apply log-space priors to log_space=True parameters from the registry. Default True.
- Return type:
- Returns:
Dictionary mapping parameter names to NumPyro distributions. Only includes parameters in
param_space.varying_names.
- heterodyne.optimization.cmc.priors.build_log_space_priors(param_names, registry=None)[source]
Build log-normal priors for parameters marked log_space=True.
For parameters where the registry’s log_space flag is set, this constructs a LogNormal distribution whose median matches the registry prior_mean (or the parameter default) and whose spread corresponds to the registry prior_std.
Parameters not flagged as log_space are silently skipped.
- Parameters:
- Return type:
- Returns:
Dictionary mapping parameter names to LogNormal distributions. Only includes parameters where
log_space=True.
- heterodyne.optimization.cmc.priors.temper_priors(priors, num_shards)[source]
Scale prior widths for Consensus MC shard sub-posteriors.
Each shard sees 1/K of the data, so its share of the prior is p(θ)^(1/K); for a Normal prior that is equivalent to widening the scale by sqrt(K). This maintains proper posterior geometry when the K sub-posteriors are combined via the consensus step.
Supported distribution types and their tempering rules:
TruncatedNormal: scale multiplied by sqrt(K).
LogNormal: scale multiplied by sqrt(K).
Uniform: left unchanged (uninformative; no tempering needed).
All others: kept unchanged, with a warning logged.
- Parameters:
priors (dict[str, Distribution]) – Dict of NumPyro distributions, one per varying parameter.
num_shards (int) – Number of CMC shards (K). Must be >= 1.
- Return type:
- Returns:
New dict with tempered distributions. Existing dict is not mutated.
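The tempering rules reduce to a dispatch on the distribution family. temper_priors operates on NumPyro Distribution objects. This sketch uses plain dicts with a hypothetical "family" key, so it runs standalone, and it keeps the same no-mutation guarantee.

```python
import math


def temper_prior_spec(spec, num_shards):
    """Return a tempered copy of one prior described as a plain dict."""
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    tempered = dict(spec)  # never mutate the caller's dict
    if spec["family"] in ("TruncatedNormal", "LogNormal"):
        tempered["scale"] = spec["scale"] * math.sqrt(num_shards)
    # Uniform (and, in the library, any other family) is left as-is.
    return tempered


def temper_all(priors, num_shards):
    """Apply temper_prior_spec to every parameter, returning a new dict."""
    return {name: temper_prior_spec(spec, num_shards) for name, spec in priors.items()}
```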
- heterodyne.optimization.cmc.priors.validate_priors(priors, param_space)[source]
Validate prior distributions against parameter space.
Checks:
All varying parameters have a corresponding prior.
Prior support overlaps with the parameter bounds (non-empty intersection).
No degenerate (effectively zero-width) priors.
A prior is considered degenerate when its extractable scale is below 1e-12. Uniform priors are never degenerate.
- Parameters:
priors (dict[str, Distribution]) – Dict of NumPyro distributions.
param_space (ParameterSpace) – Defines varying parameter names and their bounds.
- Return type:
- Returns:
List of warning/error strings. Empty list means all checks passed.
- heterodyne.optimization.cmc.priors.summarize_priors(priors)[source]
Format a human-readable summary of prior distributions.
For each prior, reports the distribution type and, where applicable, the mean, standard deviation, and support interval.
- Parameters:
priors (dict[str, Distribution]) – Dict of NumPyro distributions.
- Return type:
- Returns:
Multi-line string with one row per parameter.
- heterodyne.optimization.cmc.priors.get_param_names_in_order(vary_flags=None)[source]
Return the ordered list of parameter names that are set to vary.
Iteration order matches the registry’s insertion order, which follows the canonical group ordering defined in parameter_names.py (reference → sample → velocity → fraction → angle → scaling).
- Parameters:
vary_flags (dict[str, bool] | None) – Optional override dict mapping parameter name to a bool indicating whether that parameter varies. Parameters absent from vary_flags fall back to the registry’s vary_default attribute. Pass None to use registry defaults for all parameters.
- Return type:
- Returns:
List of parameter names for which the effective vary flag is True, in registry order.
- heterodyne.optimization.cmc.priors.validate_initial_value_bounds(init_values, param_specs=None)[source]
Check that each initial value lies within the parameter’s bounds.
- Parameters:
- Return type:
- Returns:
Mapping from parameter name to a list of warning strings. An empty dict indicates all values are within bounds.
- heterodyne.optimization.cmc.priors.build_init_values_dict(nlsq_values=None, vary_flags=None, fallback='prior_mean')[source]
Build an initial-values dict for NUTS warm-starting.
For each varying parameter the value is resolved in order:
1. NLSQ estimate from nlsq_values (if available).
2. Registry prior_mean when fallback="prior_mean" and prior_mean is not None.
3. Registry default value.
All resolved values are validated against bounds and clamped when necessary, with a logged warning per clamped parameter.
- Parameters:
nlsq_values (dict[str, float] | None) – Optional NLSQ MAP estimates keyed by parameter name.
vary_flags (dict[str, bool] | None) – Optional dict controlling which parameters vary (see get_param_names_in_order()).
fallback (str) – Strategy for parameters absent from nlsq_values. "prior_mean" uses the registry prior mean (default); "default" uses the registry default value.
- Return type:
- Returns:
Dict mapping each varying parameter name to its initial value, ready to pass to run_with_init_values().
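The resolution order and the bounds check can be sketched with plain dicts standing in for the registry. The spec layout ("default", "prior_mean", "min", "max") and the function name are hypothetical. The sketch clamps exactly to the bound. Whether the library clamps to the bound or to an interior margin is not stated in this paragraph.

```python
def resolve_init_values(names, specs, nlsq_values=None, fallback="prior_mean"):
    """Hypothetical sketch of the three-step resolution plus bounds clamping.

    specs maps name -> dict with 'default', 'prior_mean', 'min', 'max'.
    Returns (init_values, clamped_names); insertion order follows `names`.
    """
    nlsq_values = nlsq_values or {}
    init, clamped = {}, []
    for name in names:
        spec = specs[name]
        if name in nlsq_values:                        # 1. NLSQ estimate
            value = nlsq_values[name]
        elif fallback == "prior_mean" and spec.get("prior_mean") is not None:
            value = spec["prior_mean"]                 # 2. registry prior_mean
        else:
            value = spec["default"]                    # 3. registry default
        bounded = min(max(value, spec["min"]), spec["max"])
        if bounded != value:
            clamped.append(name)  # the library logs a warning per clamped parameter
        init[name] = bounded
    return init, clamped
```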
- heterodyne.optimization.cmc.priors.extract_nlsq_values_for_cmc(nlsq_result)[source]
Extract parameter values and uncertainties from an NLSQ result.
Converts the array-based NLSQResult into plain float dictionaries suitable for CMC warm-starting. NaN and inf values are filtered out so that downstream prior construction never receives non-finite inputs.
- Parameters:
nlsq_result (NLSQResult) – Converged NLSQ result with .parameters, .parameter_names, and optionally .uncertainties.
- Returns:
values maps parameter name to its fitted float value (non-finite entries excluded).
uncertainties maps parameter name to its float uncertainty, or is None when the NLSQ result carries no uncertainty information. Non-finite entries are excluded.
- Return type:
- heterodyne.optimization.cmc.priors.estimate_per_angle_scaling(data_dict, angle_keys=None)[source]
Estimate contrast and offset scaling per scattering angle.
Uses simple heuristics on the raw g2 correlation data to provide starting-point estimates for the contrast and offset scaling parameters:
contrast estimate ≈ max(g2) - min(g2) over the full lag range.
offset estimate ≈ mean of the last 10 % of g2 values (long-lag baseline), clamped to [0, 1].
These are heuristics suitable for warm-starting, not MAP estimates. The NLSQ/MCMC optimisation will refine them.
- Parameters:
- Return type:
- Returns:
Mapping of angle_key -> (contrast_estimate, offset_estimate). Keys for which data could not be parsed are silently omitted.
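Both heuristics fit in a few lines. The sketch below treats each angle's g2 as a plain list of floats and skips entries that cannot be converted, mirroring the documented "silently omitted" behaviour. Taking at least one trailing value when fewer than ten points exist is a hypothetical choice, and both function names are illustrative.

```python
import statistics


def heuristic_scaling(g2):
    """Warm-start (contrast, offset) from one angle's g2 curve."""
    contrast = max(g2) - min(g2)
    tail = g2[-max(1, len(g2) // 10):]  # last ~10 % of lags (at least one point)
    offset = min(max(statistics.fmean(tail), 0.0), 1.0)  # clamp to [0, 1]
    return contrast, offset


def heuristic_scaling_per_angle(data_dict):
    """Apply heuristic_scaling per angle key, omitting unparsable entries."""
    out = {}
    for key, g2 in data_dict.items():
        try:
            values = [float(v) for v in g2]
        except (TypeError, ValueError):
            continue
        if values:
            out[key] = heuristic_scaling(values)
    return out
```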
- heterodyne.optimization.cmc.priors.estimate_contrast_offset_from_data(c2_data, t1, t2, contrast_bounds=(0.0, 1.0), offset_bounds=(0.5, 1.5), lag_floor_quantile=0.80, lag_ceiling_quantile=0.20, value_quantile_low=0.10, value_quantile_high=0.90)[source]
Estimate contrast and offset from C2 data via physics-informed quantile analysis.
Uses the correlation decay: C2 = contrast × g1² + offset. At large lags g1² → 0 so C2 → offset; at small lags g1² ≈ 1 so C2 ≈ contrast + offset.
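The decay argument suggests a quantile recipe. Offset is read from the large-lag points, where g1² ≈ 0. Contrast plus offset is read from the small-lag points, where g1² ≈ 1. The sketch below is one plausible reading of the documented keyword arguments and is not the library's exact algorithm. In particular, it assumes that value_quantile_low applies to the large-lag set and value_quantile_high to the small-lag set. The quantile helper uses linear interpolation, the same convention as numpy's default.

```python
def quantile(values, q):
    """Linear-interpolation quantile (numpy's default convention)."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def _clip(x, bounds):
    return min(max(x, bounds[0]), bounds[1])


def estimate_contrast_offset(c2, t1, t2, contrast_bounds=(0.0, 1.0),
                             offset_bounds=(0.5, 1.5), lag_floor_quantile=0.80,
                             lag_ceiling_quantile=0.20, value_quantile_low=0.10,
                             value_quantile_high=0.90):
    lags = [abs(a - b) for a, b in zip(t1, t2)]
    large_cut = quantile(lags, lag_floor_quantile)
    small_cut = quantile(lags, lag_ceiling_quantile)
    large = [c for c, lag in zip(c2, lags) if lag >= large_cut]
    small = [c for c, lag in zip(c2, lags) if lag <= small_cut]
    offset = quantile(large, value_quantile_low)              # C2 -> offset
    contrast = quantile(small, value_quantile_high) - offset  # C2 ~ contrast + offset
    return _clip(contrast, contrast_bounds), _clip(offset, offset_bounds)
```

On synthetic data C2 = 1.0 + 0.3 × 0.1^lag over all pairs of a 10-point grid, this recovers contrast ≈ 0.3 and offset ≈ 1.0.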
- heterodyne.optimization.cmc.priors.validate_init_values_order(init_values, expected_names)[source]
Validate that init-values key order matches the expected parameter order.
Homodyne CMC parity helper. Python 3.7+ dicts preserve insertion order, so positional consumers of init_values (e.g. when zipping with NLSQ arrays) depend on a stable iteration order. This check makes ordering bugs fail loudly with a descriptive error instead of producing silently wrong parameter bindings.
- Parameters:
- Raises:
ValueError – When the key count or per-position order disagrees with expected_names. The error names the first mismatching index and quotes the full lists for fast diagnosis.
- Return type:
- heterodyne.optimization.cmc.priors.BOUNDARY_INTERIOR_MARGIN: float = 0.05
Default boundary-clamp margin: 5% of the bound range (linear) or 5% of the log-range (geometric, for log_space=True parameters). NUTS leapfrog step-size adaptation collapses if the chain initialises at a TruncatedNormal boundary; this margin keeps the warm-start away from the reflecting wall.
- heterodyne.optimization.cmc.priors.clamp_params_to_interior(params, parameter_names, *, margin=BOUNDARY_INTERIOR_MARGIN)[source]
Shift parameter values inward from hard bounds (raw-array path).
Lower-level companion to clamp_to_interior() (the NLSQResult entry point). Named separately so the two public APIs don’t share a symbol: callers without an NLSQResult (e.g. direct Python users passing raw arrays to fit_cmc_jax()) use this entry point.
Linear-scale parameters are clamped to [lo + margin * range, hi - margin * range], where range = hi - lo. Log-space parameters (D0_ref, D0_sample, v0: registry log_space=True, positive min_bound) use a geometric margin so that a 5% linear fraction of a multi-decade range does not produce a 500× clamp target.
- Parameters:
- Return type:
- Returns:
(new_params, clamped_names): a fresh array with values shifted inward and the list of names that were actually moved.
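The linear and geometric margins differ only in the space where the 5% is measured. The sketch below works on dicts keyed by parameter name rather than on the raw array that clamp_params_to_interior takes. The function name and the dict layout are hypothetical.

```python
import math


def clamp_to_interior_sketch(values, bounds, log_space, margin=0.05):
    """Shift values inward from hard bounds; returns (new_values, moved_names)."""
    new_values = dict(values)
    moved = []
    for name, x in values.items():
        lo, hi = bounds[name]
        if log_space.get(name, False) and lo > 0:
            # Geometric margin: a fraction of the log-range, applied in log space.
            log_lo, log_hi = math.log(lo), math.log(hi)
            span = log_hi - log_lo
            inner_lo = math.exp(log_lo + margin * span)
            inner_hi = math.exp(log_hi - margin * span)
        else:
            # Linear margin: a fraction of hi - lo.
            span = hi - lo
            inner_lo, inner_hi = lo + margin * span, hi - margin * span
        clamped = min(max(x, inner_lo), inner_hi)
        if clamped != x:
            new_values[name] = clamped
            moved.append(name)
    return new_values, moved
```

For a log-space parameter bounded by (1, 10⁶) and sitting on its lower bound, the geometric margin moves it to 10^0.3 ≈ 2.0. A 5% linear margin would instead push it to about 5 × 10⁴.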
- heterodyne.optimization.cmc.priors.clamp_to_interior(result, fixed_param_overrides=None, *, margin=BOUNDARY_INTERIOR_MARGIN)[source]
Return a copy of result with parameters shifted inward from hard bounds.
NUTS step-size collapses when the chain initialises at a TruncatedNormal boundary. Linear-scale parameters are clamped to [min_bound + margin, max_bound - margin], where margin is 5% of the bound range. Log-space parameters (D0, v0) use a geometric margin so that a linear fraction of a multi-decade range does not produce an absurd clamp target. Both ensure the leapfrog step-size adaptation starts well away from the reflecting wall.
fixed_param_overrides maps parameter names to values from the current model config’s fixed_parameters. When provided, any NLSQ result value for a fixed parameter is replaced with the config value before bounds clamping. This prevents a stale nlsq_data.npz (fitted in a prior run where the parameter was free) from propagating a superseded value into CMC initialisation. Such a value can place the warm-start outside the reparameterised prior support, giving log-prior = −∞ and hence BFMI = 0 across all shards.
- Parameters:
- Return type:
NLSQResult
- Returns:
A new NLSQResult with parameters clamped if any were out of the safe interior; the original result itself when no clamping was needed.
- class heterodyne.optimization.cmc.priors.PriorBuilder[source]
Bases: object
Construct NumPyro priors from a parameter registry.
- Parameters:
registry (ParameterRegistry | None) – Parameter registry (defaults to heterodyne.config.parameter_registry.DEFAULT_REGISTRY).
use_log_space_priors (bool) – When True (default), parameters flagged log_space=True in the registry are returned as LogNormal distributions instead of TruncatedNormal, for better mass-matrix conditioning of multi-decade prefactors (D0_ref, D0_sample, v0). See codex S1.
- Raises:
RuntimeError – When the registry and parameter_space._DEFAULT_PRIOR_SPECS disagree on (prior_mean, prior_std) vs (loc, scale). This is Rule 9 from CLAUDE.md: the dual-prior system must stay in sync. Tolerance is rel_tol=1e-6, abs_tol=1e-9.
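The construction-time sync gate amounts to a tolerance comparison of two (mean, std) tables, and math.isclose accepts exactly the rel_tol/abs_tol pair quoted above. In the sketch, plain dicts of tuples stand in for the registry and _DEFAULT_PRIOR_SPECS. Skipping names that appear in only one table is a hypothetical choice; the library's handling of such names is not described here.

```python
import math


def check_prior_sync(registry_priors, spec_priors, rel_tol=1e-6, abs_tol=1e-9):
    """Raise RuntimeError if two (mean, std) tables disagree beyond tolerance."""
    for name, (prior_mean, prior_std) in registry_priors.items():
        if name not in spec_priors:
            continue  # hypothetical: only names present in both tables are compared
        loc, scale = spec_priors[name]
        same_loc = math.isclose(prior_mean, loc, rel_tol=rel_tol, abs_tol=abs_tol)
        same_scale = math.isclose(prior_std, scale, rel_tol=rel_tol, abs_tol=abs_tol)
        if not (same_loc and same_scale):
            raise RuntimeError(
                f"prior desync for {name!r}: registry ({prior_mean}, {prior_std}) "
                f"vs spec ({loc}, {scale})"
            )
```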
- __init__(registry=None, use_log_space_priors=True)[source]
- build(param_space)[source]
Build priors for the varying parameters in param_space.
- Parameters:
param_space (ParameterSpace) – ParameterSpace defining which parameters vary and their physical bounds.
- Return type:
- Returns:
Dictionary mapping each varying parameter name to its NumPyro distribution. D0_ref/D0_sample/v0 are LogNormal when use_log_space_priors=True and TruncatedNormal-family otherwise; remaining parameters are always TruncatedNormal.
- heterodyne.optimization.cmc.priors.build_default_priors_via_builder(param_space, registry=None, use_log_space_priors=True)[source]
Functional shim around PriorBuilder for callers that want a one-shot factory. Constructing the builder runs the sync gate, so each call validates the dual-prior invariant.
- Return type:
- heterodyne.optimization.cmc.priors.build_log_space_priors_via_builder(param_names, registry=None)[source]
Functional shim for the legacy log-space-only helper.
The construction-time gate runs here too: calling this entry point is enough to fail loudly on a desynced registry/spec pair.
- Return type:
Sampler
High-level NUTS sampler wrapper for heterodyne CMC analysis.
Provides a SamplingPlan dataclass for sampling hyperparameters and a
NUTSSampler class that wraps NumPyro’s MCMC with ergonomic factories,
automatic chain initialization with perturbation, and ArviZ diagnostics.
- class heterodyne.optimization.cmc.sampler.SamplingPlan[source]
Bases: object
Hyperparameters for NUTS sampling.
Immutable configuration that fully specifies a sampling run.
- num_warmup
Number of warmup (adaptation) steps per chain.
- num_samples
Number of posterior draws per chain after warmup.
- num_chains
Number of independent MCMC chains.
- target_accept
Target acceptance probability for dual-averaging step-size adaptation. Values in [0.6, 0.95] are typical.
- max_tree_depth
Maximum binary tree depth for NUTS. Higher values allow longer trajectories but increase per-step cost.
- adapt_step_size
Whether to use dual-averaging step-size adaptation during warmup.
- dense_mass
Whether to estimate a dense (full) mass matrix during warmup, or use a diagonal approximation.
- seed
Explicit random seed for reproducibility. If None, a cryptographically random seed is generated.
- num_warmup: int = 500
- num_samples: int = 1000
- num_chains: int = 4
- target_accept: float = 0.8
- max_tree_depth: int = 10
- adapt_step_size: bool = True
- dense_mass: bool = True
- chain_method: str = 'sequential'
- fast_warmup: bool = False
Rule 12 escape hatch propagated from CMCConfig.fast_warmup. When True, for_shard and AdaptiveSamplingPlan skip the dense-mass warmup floor. Intended for CI / pytest fast mode only; not for production posteriors.
- property effective_seed: int
Return the seed, generating one if not explicitly set.
- classmethod from_config(config, n_data=None, n_params=None)[source]
Build a SamplingPlan from a CMCConfig.
Applies adaptive scaling when config.adaptive_sampling is True and n_data is provided: warmup and sample counts are scaled proportionally to the ratio n_data / _REFERENCE_SHARD_SIZE and clamped to the configured floors.
- Parameters:
config (CMCConfig) – CMC configuration carrying all NUTS hyperparameters and adaptive-sampling knobs.
n_data (int | None) – Number of data points in this shard (or the full dataset when sharding is disabled). When None, no adaptive scaling is applied regardless of config.adaptive_sampling.
n_params (int | None) – Number of varying model parameters. Reserved for future dimension-aware scaling; currently unused.
- Return type:
SamplingPlan
- Returns:
Fully validated SamplingPlan.
- for_shard(shard_size, full_size)[source]
Return a scaled-down plan appropriate for a single CMC shard.
Scales warmup and sample counts by sqrt(shard_size / full_size) to reflect the reduced information content of the shard. Counts are clamped to a minimum of max(1, num_x // 10), where num_x stands for the unscaled num_warmup or num_samples, to avoid degenerate one-step runs.
- Parameters:
- Return type:
SamplingPlan
- Returns:
New SamplingPlan with adjusted warmup/sample counts and the same seed and other hyperparameters.
- Raises:
ValueError – If shard_size <= 0 or full_size <= 0.
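The square-root rule can be sketched as a standalone function. This is an illustration under assumptions, not the library's code: `scale_counts_for_shard` is a hypothetical helper that returns the two counts rather than a new SamplingPlan, and it reads `num_x` in the floor as the unscaled count.

```python
import math


def scale_counts_for_shard(
    num_warmup: int, num_samples: int, shard_size: int, full_size: int
) -> tuple[int, int]:
    """Hypothetical sketch of the for_shard count scaling (not the library API)."""
    if shard_size <= 0 or full_size <= 0:
        raise ValueError("shard_size and full_size must both be positive")
    # Square-root scaling: a shard holding 1/4 of the data keeps 1/2 of the steps.
    factor = math.sqrt(shard_size / full_size)

    def scaled(count: int) -> int:
        # Floor of max(1, count // 10) avoids degenerate one-step runs.
        floor = max(1, count // 10)
        return max(floor, int(count * factor))

    return scaled(num_warmup), scaled(num_samples)


print(scale_counts_for_shard(500, 1000, shard_size=2_500, full_size=10_000))  # (250, 500)
```

Square-root rather than linear scaling keeps small shards from being starved of warmup while still reducing cost.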
- __init__(num_warmup=500, num_samples=1000, num_chains=4, target_accept=0.8, max_tree_depth=10, adapt_step_size=True, dense_mass=True, chain_method='sequential', seed=None, fast_warmup=False)
- class heterodyne.optimization.cmc.sampler.NUTSSampler[source]
Bases: object
High-level NUTS sampler wrapping NumPyro’s MCMC.
Manages kernel construction, chain initialization with perturbation, sampling execution, and ArviZ diagnostic extraction.
Use the from_plan() factory for the standard construction path.
- __init__(mcmc, plan)[source]
- property plan: SamplingPlan
The sampling plan used to configure this sampler.
- classmethod from_plan(plan, model, init_strategy='init_to_median', chain_method=None)[source]
Create a NUTSSampler from a SamplingPlan and NumPyro model.
- Parameters:
plan (SamplingPlan) – Sampling hyperparameters.
model (Callable[..., Any]) – NumPyro model function (callable with no required args).
init_strategy (str) – NumPyro initialization strategy name. One of "init_to_median", "init_to_sample", "init_to_value".
chain_method (str | None) – NumPyro chain execution method: "sequential" for a single device, "parallel" for multiple devices.
- Return type:
NUTSSampler
- Returns:
Configured NUTSSampler ready for run().
- run(rng_key=None, init_params=None)[source]
Run MCMC sampling.
If init_params are provided, small random perturbations are added per chain to break symmetry and improve exploration.
- Parameters:
- Return type:
- Returns:
Dictionary of posterior samples (ungrouped).
- Raises:
RuntimeError – If sampling fails.
- run_with_init_values(init_values, rng_key=None)[source]
Run MCMC seeded from NLSQ warm-start values.
Validates that the initial log density is finite before launching full sampling, raising early with a diagnostic message if not.
- Parameters:
- Return type:
- Returns:
Dictionary of posterior samples (ungrouped).
- Raises:
RuntimeError – If the initial log density is not finite or if sampling itself fails.
- get_divergence_stats()[source]
Extract divergence rate and tree-depth statistics from the last run.
Requires that run() or run_with_init_values() has been called.
- Returns:
"divergence_rate" – Fraction of post-warmup transitions that were divergent. Zero when no divergences were recorded.
"mean_tree_depth" – Mean NUTS tree depth across all post-warmup samples and chains. Values near max_tree_depth indicate the trajectory is being truncated.
"max_tree_depth_fraction" – Fraction of samples that hit the maximum tree depth (plan.max_tree_depth).
- Return type:
- Raises:
RuntimeError – If called before run().
- get_diagnostics()[source]
Extract ArviZ InferenceData for convergence diagnostics.
- Return type:
InferenceData
- Returns:
ArviZ InferenceData containing posterior samples, sample stats (energy, divergences), and warmup statistics.
- Raises:
RuntimeError – If called before run().
- log_adapter_diagnostics(run_logger=None)[source]
Log NUTS adapter state at INFO level (homodyne CMC parity helper).
Reports the adapted step_size per chain, a compact summary of the adapted inverse mass matrix, and per-step accept_prob/num_steps/potential_energy statistics when the corresponding extra fields are present.
- Parameters:
run_logger (Any | None) – Logger to emit through. None uses this module’s logger.
- Raises:
RuntimeError – If called before run().
- Return type:
- property mcmc: MCMC
Access the underlying NumPyro MCMC object.
- class heterodyne.optimization.cmc.sampler.AdaptiveSamplingPlan[source]
Bases: object
Sampling plan that adjusts warmup/sample counts based on shard size.
Wraps a base SamplingPlan and scales it down by the square root of the size ratio when the shard is smaller than a reference size, while respecting floors derived from parameter count.
The scaling rule is:
scale = sqrt(shard_size / reference_shard_size)
num_warmup = max(min_warmup_floor, int(base.num_warmup * scale))
num_samples = max(min_samples_floor, int(base.num_samples * scale))
where min_warmup_floor = max(50, 5 * n_params) and min_samples_floor = max(100, 10 * n_params).
- base_plan
Base SamplingPlan for a full-size shard.
- shard_size
Number of data points in this shard.
- n_params
Number of varying model parameters. Used to set minimum sample-count floors.
- base_plan: SamplingPlan
- shard_size: int
- n_params: int
- get_plan()[source]
Return a SamplingPlan adjusted for this shard.
Scaling is sub-linear (square-root) so that small shards still receive enough samples to characterise the posterior, while large shards are not penalised by excessive warmup.
The floor on warmup is max(50, 5 * n_params): enough adaptation steps to approximate the mass matrix for a 14-parameter model (floor = 70 steps). The floor on samples is max(100, 10 * n_params): enough draws for basic ESS diagnostics.
- Return type:
SamplingPlan
- Returns:
Scaled SamplingPlan.
- __init__(base_plan, shard_size, n_params, _reference_shard_size=10000)
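The get_plan scaling rule, with its parameter-count floors, can be sketched as follows. `adaptive_counts` is a hypothetical helper; the default `reference_shard_size=10_000` mirrors the `_reference_shard_size` default in the signature above. The sketch applies the formula literally and does not model fast_warmup or any cap for shards larger than the reference.

```python
import math


def adaptive_counts(
    base_warmup: int,
    base_samples: int,
    shard_size: int,
    n_params: int,
    reference_shard_size: int = 10_000,
) -> tuple[int, int]:
    """Hypothetical sketch of AdaptiveSamplingPlan.get_plan's scaling rule."""
    scale = math.sqrt(shard_size / reference_shard_size)
    # Parameter-count floors: warmup for mass-matrix adaptation, draws for ESS.
    min_warmup_floor = max(50, 5 * n_params)
    min_samples_floor = max(100, 10 * n_params)
    num_warmup = max(min_warmup_floor, int(base_warmup * scale))
    num_samples = max(min_samples_floor, int(base_samples * scale))
    return num_warmup, num_samples


# A 14-parameter model on a tiny 100-point shard hits both floors.
print(adaptive_counts(500, 1000, shard_size=100, n_params=14))  # (70, 140)
```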
- heterodyne.optimization.cmc.sampler.DIVERGENCE_RATE_TARGET: float = 0.01
Target divergence rate: below this level the run is considered healthy.
- heterodyne.optimization.cmc.sampler.DIVERGENCE_RATE_HIGH: float = 0.05
Elevated divergence rate: triggers a retry in run_nuts_with_retry().
- heterodyne.optimization.cmc.sampler.DIVERGENCE_RATE_CRITICAL: float = 0.1
Critical divergence rate: posterior geometry is likely incompatible with HMC.
- class heterodyne.optimization.cmc.sampler.SamplingStats[source]
Bases: object
Summary statistics from a completed NUTS sampling run.
- num_samples
Number of posterior draws per chain (post-warmup).
- num_warmup
Number of warmup steps per chain.
- num_divergences
Total divergent transitions across all chains.
- divergence_rate
Fraction of post-warmup transitions that diverged.
- mean_accept_prob
Mean Metropolis acceptance probability.
- max_tree_depth_fraction
Fraction of samples that hit the maximum NUTS tree depth.
- wall_time_seconds
Elapsed wall-clock time for the run.
- num_samples: int
- num_warmup: int
- num_divergences: int
- divergence_rate: float
- mean_accept_prob: float
- max_tree_depth_fraction: float
- wall_time_seconds: float
- property is_healthy: bool
Return True when divergence rate and acceptance probability are acceptable.
Criteria:
divergence_rate < 0.05 (below DIVERGENCE_RATE_HIGH)
mean_accept_prob > 0.6
- __init__(num_samples, num_warmup, num_divergences, divergence_rate, mean_accept_prob, max_tree_depth_fraction, wall_time_seconds)
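The health criteria and the divergence thresholds above interact as sketched below. These are hypothetical free functions mirroring the documented rules, not the SamplingStats implementation; the strict `>` in `would_retry` follows the "exceeds DIVERGENCE_RATE_HIGH" wording used for run_nuts_with_retry.

```python
DIVERGENCE_RATE_TARGET = 0.01
DIVERGENCE_RATE_HIGH = 0.05


def is_healthy(divergence_rate: float, mean_accept_prob: float) -> bool:
    """Documented SamplingStats.is_healthy criteria."""
    return divergence_rate < DIVERGENCE_RATE_HIGH and mean_accept_prob > 0.6


def would_retry(divergence_rate: float) -> bool:
    """run_nuts_with_retry retries when the rate exceeds DIVERGENCE_RATE_HIGH."""
    return divergence_rate > DIVERGENCE_RATE_HIGH


def meets_target(divergence_rate: float) -> bool:
    """Stricter check against DIVERGENCE_RATE_TARGET."""
    return divergence_rate < DIVERGENCE_RATE_TARGET
```

A run with a 3% divergence rate and 0.85 mean acceptance passes is_healthy and triggers no retry, yet it still sits above the 1% target.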
- heterodyne.optimization.cmc.sampler.run_nuts_with_retry(sampler, model_fn, model_kwargs, max_retries=3, target_accept_increment=0.05, *, step_size_factor=None)[source]
Run NUTS sampling with automatic step-size reduction on high divergence.
Executes run() and checks the divergence rate after each attempt. When the rate exceeds DIVERGENCE_RATE_HIGH, a new NUTSSampler is built with target_accept RAISED by target_accept_increment (which drives dual averaging toward a SMALLER step size, the mathematically correct response to a high divergence rate) and the run is retried. Once all max_retries retries are exhausted, the attempt with the lowest divergence rate is returned regardless of health. Mirrors homodyne run_nuts_with_retry (sampler.py:1311-1314).
The model_fn is re-used across retries, so it must be stateless (i.e. a pure NumPyro model function with no side effects).
- Parameters:
sampler (NUTSSampler) – Configured NUTSSampler for the first attempt.
model_fn (Any) – NumPyro model callable. Not called directly here but passed to NUTSSampler.from_plan() for retry instances.
model_kwargs (dict[str, Any]) – Keyword arguments forwarded to the model via run(). Currently unused by run() (which takes rng_key and init_params); included for forward compatibility.
max_retries (int) – Maximum number of additional attempts after the first run. Total runs = max_retries + 1.
target_accept_increment (float) – Additive increase applied to target_accept on each retry (e.g. 0.80 → 0.85 → 0.90 → 0.95). Raising the target acceptance LOWERS the dual-averaging step size, which is the mathematically correct response to high divergence rates. Must be in (0, 0.5). The target is clamped at 0.99 to avoid pathologically tiny step sizes.
step_size_factor (float | None) – DEPRECATED keyword-only alias. Earlier versions multiplied target_accept by this factor (with factor < 1) on retry, which drove dual averaging toward a LARGER step size, the OPPOSITE of what a high divergence rate calls for. When supplied, target_accept_increment is derived as (1 - step_size_factor) * 0.1 so legacy callers get the corrected direction.
- Return type:
- Returns:
Tuple of (samples_dict, SamplingStats) for the best attempt (lowest divergence rate).
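The target_accept sequence tried by the retry loop, including the 0.99 clamp and the legacy step_size_factor conversion, can be sketched as below. `target_accept_schedule` and `pick_best` are hypothetical helpers; the real function also rebuilds a NUTSSampler and reruns it at each step, which is not modeled here.

```python
from typing import Optional


def target_accept_schedule(
    initial: float,
    max_retries: int = 3,
    target_accept_increment: float = 0.05,
    step_size_factor: Optional[float] = None,
    cap: float = 0.99,
) -> list[float]:
    """Entry 0 is the first run; entries 1..max_retries are the retries."""
    if step_size_factor is not None:
        # Legacy alias, converted to an additive increment in the corrected direction.
        target_accept_increment = (1.0 - step_size_factor) * 0.1
    if not 0.0 < target_accept_increment < 0.5:
        raise ValueError("target_accept_increment must be in (0, 0.5)")
    targets = [initial]
    for _ in range(max_retries):
        # Raising target_accept pushes dual averaging toward a smaller step size.
        targets.append(min(cap, targets[-1] + target_accept_increment))
    return targets


def pick_best(attempts: list[tuple[dict, float]]) -> tuple[dict, float]:
    """Return the (samples, divergence_rate) pair with the lowest rate."""
    return min(attempts, key=lambda attempt: attempt[1])
```

With the defaults, a run starting at 0.80 tries 0.80, 0.85, 0.90 and 0.95; a legacy step_size_factor of 0.5 maps to the same 0.05 increment.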
- heterodyne.optimization.cmc.sampler.compute_mcmc_safe_initial_values(initial_values, *, q, dt, time_grid, target_g1=0.5, g1_threshold=0.1)[source]
Detect/repair vanishing-g1 initial parameters for heterodyne NUTS.
Inspects both the reference and sample transport components and, when either drives g1 below g1_threshold at a typical lag, rescales its (D0, D_offset) pair so that g1 ≈ target_g1. Returns None when no adjustment is needed so callers can short-circuit.
- Parameters:
initial_values (dict[str, float] | None) – NLSQ warm-start dictionary. Expected keys include D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample, D_offset_sample.
q (float) – Wavevector magnitude (Å⁻¹).
dt (float) – Lag-time step (s).
target_g1 (float) – Target g1 value when rescaling.
g1_threshold (float) – Threshold below which rescaling fires.
- Return type:
- Returns:
A new dictionary with adjusted D0_*/D_offset_* values when adjustment is required, else None.
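One way to pick the rescaling factor is sketched below. It rests on an assumption not stated above: g1 decays as exp(-c · ∫D(t) dt) with D(t) = D0·t^α + D_offset, so at fixed alpha the exponent is linear in the (D0, D_offset) pair. Multiplying both by s then turns g1 into g1^s, and s = ln(target_g1) / ln(g1) lands exactly on the target. `rescale_transport_pair` is hypothetical and takes a precomputed g1 rather than q, dt and time_grid.

```python
import math
from typing import Optional


def rescale_transport_pair(
    d0: float,
    d_offset: float,
    g1_current: float,
    target_g1: float = 0.5,
    g1_threshold: float = 0.1,
) -> Optional[tuple[float, float]]:
    """Hypothetical sketch: scale (D0, D_offset) so g1 moves to target_g1.

    Assumes log(g1) is linear in (D0, D_offset) at fixed alpha.
    """
    if not 0.0 < g1_current < 1.0:
        raise ValueError("g1_current must lie strictly between 0 and 1")
    if g1_current >= g1_threshold:
        return None  # no adjustment needed; caller can short-circuit
    s = math.log(target_g1) / math.log(g1_current)
    return d0 * s, d_offset * s
```

Scaling both members by the same factor preserves the D_offset/D0 ratio, so only the overall decay rate changes.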
Diagnostics¶
Convergence diagnostics for CMC analysis.
- class heterodyne.optimization.cmc.diagnostics.ConvergenceReport[source]
Bases: object
Report of convergence diagnostic checks.
- passed: bool
- r_hat_passed: bool
- ess_passed: bool
- bfmi_passed: bool
- __init__(passed, r_hat_passed, ess_passed, bfmi_passed, messages)
- heterodyne.optimization.cmc.diagnostics.validate_convergence(result, r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3)[source]
Validate MCMC convergence from CMC result.
Checks:
1. R-hat (Gelman-Rubin statistic) < threshold for all parameters.
2. Effective sample size (ESS) > minimum for all parameters.
3. Bayesian Fraction of Missing Information (BFMI) > minimum.
- heterodyne.optimization.cmc.diagnostics.compute_posterior_contraction(result, prior_std)[source]
Compute Posterior Contraction Ratio for each parameter.
PCR = 1 - posterior_std / prior_std
- Interpretation:
~1.0 = strongly constrained by data
~0.0 = poorly identified (prior dominates)
<0 = possible model misspecification (posterior wider than prior)
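The PCR formula applied to per-parameter standard deviations looks like this. `posterior_contraction` is a hypothetical helper taking plain dicts; the documented function takes a CMC result instead, and skipping non-positive prior widths is an assumption of the sketch.

```python
def posterior_contraction(
    posterior_std: dict[str, float], prior_std: dict[str, float]
) -> dict[str, float]:
    """PCR = 1 - posterior_std / prior_std for parameters present in both dicts."""
    return {
        name: 1.0 - posterior_std[name] / prior_std[name]
        for name in posterior_std
        if name in prior_std and prior_std[name] > 0
    }


pcr = posterior_contraction(
    {"D0_ref": 0.5, "alpha_ref": 0.3, "v0": 3.0},
    {"D0_ref": 2.0, "alpha_ref": 0.3, "v0": 1.5},
)
print(pcr)  # {'D0_ref': 0.75, 'alpha_ref': 0.0, 'v0': -1.0}
```

Here D0_ref is well constrained, alpha_ref is prior-dominated, and v0's negative PCR would flag possible misspecification.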
- heterodyne.optimization.cmc.diagnostics.compute_r_hat(samples)[source]
Compute rank-normalized R-hat from samples.
Delegates to arviz.rhat(), which implements the recommended rank-normalized split-R-hat from Vehtari et al. (2021).
- heterodyne.optimization.cmc.diagnostics.compute_ess(samples)[source]
Compute effective sample size.
Delegates to arviz.ess(), which uses FFT-based autocorrelation with Geyer’s initial monotone sequence estimator.
- heterodyne.optimization.cmc.diagnostics.compute_bfmi(energy)[source]
Compute Bayesian Fraction of Missing Information.
Delegates to arviz.bfmi(). Values < 0.3 indicate potential problems with HMC sampling.
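For intuition, the E-BFMI estimator is the mean squared jump in energy between successive draws divided by the marginal energy variance. To our understanding this matches the per-chain ratio arviz.bfmi() computes, but the pure-Python version below is only an illustrative sketch for a single chain.

```python
import statistics


def bfmi_single_chain(energy: list[float]) -> float:
    """E-BFMI for one chain: mean squared energy jump over the energy variance."""
    jumps_sq = [(b - a) ** 2 for a, b in zip(energy, energy[1:])]
    return statistics.fmean(jumps_sq) / statistics.pvariance(energy)


# Energy that drifts slowly between draws explores the energy distribution poorly.
print(bfmi_single_chain([1, 2, 3, 4, 5, 6, 7, 8]))  # well below the 0.3 warning level
```

A low ratio means each momentum resample moves the chain only a small distance through the energy distribution, which is the failure mode the 0.3 threshold targets.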
- class heterodyne.optimization.cmc.diagnostics.DivergenceReport[source]
Bases: object
Divergence rate analysis.
- divergence_rate: float
- n_divergent: int
- n_total: int
- severity: str
- __init__(divergence_rate, n_divergent, n_total, severity, messages)
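The DivergenceReport counts relate through divergence_rate = n_divergent / n_total. The sketch below computes these counts from a (n_chains, n_draws) grid of "diverging" flags, together with the per-chain rates that analyze_divergences uses for its messages. `divergence_rates` is a hypothetical helper, and it does not model severity classification.

```python
def divergence_rates(diverging: list[list[bool]]) -> tuple[float, list[float]]:
    """Global and per-chain divergence rates from a (n_chains, n_draws) flag grid."""
    per_chain = [sum(map(bool, chain)) / len(chain) for chain in diverging]
    n_divergent = sum(sum(map(bool, chain)) for chain in diverging)
    n_total = sum(len(chain) for chain in diverging)
    return n_divergent / n_total, per_chain
```

Reporting per-chain rates alongside the pooled rate helps tell a single stuck chain apart from geometry problems shared by all chains.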
- heterodyne.optimization.cmc.diagnostics.analyze_divergences(samples)[source]
Analyse divergent transitions from MCMC samples.
Accepts either a raw samples dict that may contain a "diverging" field (shape (n_chains, n_draws)) or a CMCResult whose extra_fields attribute holds that field.
Divergence rate is computed globally (all chains combined) and per-chain for contextual messages.
- heterodyne.optimization.cmc.diagnostics.validate_convergence_sharded(results, r_hat_threshold=1.1, min_ess=100, min_bfmi=0.3)[source]
Validate convergence across CMC shards.
Runs validate_convergence() on each shard and returns a combined ConvergenceReport that reflects the worst-case R-hat and the minimum ESS observed across all shards. A single failing shard causes the combined report to fail.
- Parameters:
- Return type:
ConvergenceReport
- Returns:
Combined ConvergenceReport with worst-case statistics.
- heterodyne.optimization.cmc.diagnostics.compute_trace_diagnostics(samples, lags=(1, 5, 10))[source]
Compute trace-level diagnostics for a single parameter’s samples.
- Parameters:
- Returns:
autocorr – Dict mapping each lag to its mean autocorrelation across chains.
stationary – True if the absolute autocorrelation at lag 1 is below 0.5 for all chains (heuristic stationarity flag).
mixing_quality – One of "good", "moderate", or "poor", based on lag-1 autocorrelation magnitude.
n_chains – Number of chains.
n_draws – Number of draws per chain.
mean – Grand mean across all draws.
std – Grand standard deviation across all draws.
- Return type:
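The autocorr and stationary entries can be sketched with the standard biased lag-k autocorrelation, normalised by the lag-0 sum of squares. The helper names are hypothetical. Returning 0.0 for a constant chain is an assumption, and the sketch omits the mixing_quality thresholds, which are not given above.

```python
def autocorr(chain: list[float], lag: int) -> float:
    """Lag-k sample autocorrelation of one chain (biased, normalised by lag 0)."""
    n = len(chain)
    mean = sum(chain) / n
    dev = [x - mean for x in chain]
    denom = sum(d * d for d in dev)
    if denom == 0:
        return 0.0  # constant chain: treated as uncorrelated (assumption)
    return sum(dev[t] * dev[t + lag] for t in range(n - lag)) / denom


def trace_autocorr(chains: list[list[float]], lags=(1, 5, 10)):
    """Mean autocorrelation across chains per lag, plus the lag-1 stationarity flag."""
    mean_ac = {lag: sum(autocorr(c, lag) for c in chains) / len(chains) for lag in lags}
    stationary = all(abs(autocorr(c, 1)) < 0.5 for c in chains)
    return mean_ac, stationary
```

A steadily trending trace such as 1, 2, ..., 10 has lag-1 autocorrelation 0.7 and fails the flag.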
- heterodyne.optimization.cmc.diagnostics.compute_pair_correlations(samples)[source]
Compute pairwise Pearson correlations between parameters.
Flattens all chains and draws for each parameter before computing correlations, so the result is chain-agnostic.
Useful for detecting parameter degeneracy: correlations with |r| > 0.9 indicate near-redundant parameters.
- Parameters:
samples (dict[str, ndarray]) – Mapping from parameter name to array of shape (n_chains, n_draws) or (n_draws,).
- Return type:
- Returns:
Nested dict corr[param_a][param_b] = r, where r is the Pearson correlation coefficient in [-1, 1]. The matrix is symmetric with ones on the diagonal. Returns an empty dict if fewer than two parameters are provided.
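The flatten-then-correlate procedure and the |r| > 0.9 degeneracy check can be sketched in pure Python. All helper names here are hypothetical. Returning NaN for a constant parameter is an assumption, because the Pearson coefficient is undefined in that case.

```python
import math


def _flatten(values):
    """Flatten (n_chains, n_draws) nested lists; pass 1-D lists through."""
    if values and isinstance(values[0], (list, tuple)):
        return [x for chain in values for x in chain]
    return list(values)


def pearson(x, y):
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    if sxx == 0 or syy == 0:
        return math.nan  # undefined for a constant parameter
    return sxy / math.sqrt(sxx * syy)


def pair_correlations(samples):
    if len(samples) < 2:
        return {}
    flat = {name: _flatten(v) for name, v in samples.items()}
    names = list(flat)
    corr = {name: {name: 1.0} for name in names}
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            corr[a][b] = corr[b][a] = pearson(flat[a], flat[b])
    return corr


def near_redundant(corr, threshold=0.9):
    """Unordered parameter pairs with |r| > threshold."""
    names = list(corr)
    return [(a, b) for i, a in enumerate(names) for b in names[i + 1:]
            if abs(corr[a][b]) > threshold]
```

Because every parameter is flattened in the same chain-major order, draw k of one parameter stays paired with draw k of the others.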
- heterodyne.optimization.cmc.diagnostics.check_convergence(r_hat, ess_bulk, divergences, n_samples, n_chains, max_rhat=DEFAULT_MAX_RHAT, min_ess=DEFAULT_MIN_ESS, max_divergence_rate=DEFAULT_MAX_DIVERGENCE_RATE, num_shards=1)[source]
Check convergence criteria and return (status, warnings).
- heterodyne.optimization.cmc.diagnostics.create_diagnostics_dict(r_hat, ess_bulk, ess_tail, divergences, convergence_status, warnings, n_chains, n_warmup, n_samples, warmup_time, sampling_time, num_shards=1)[source]
Build a diagnostics dictionary suitable for JSON serialization.
- heterodyne.optimization.cmc.diagnostics.summarize_diagnostics(r_hat, ess_bulk, divergences, n_samples, n_chains, num_shards=1)[source]
Create human-readable diagnostics summary (homodyne-parity).
Mirrors homodyne diagnostics.summarize_diagnostics. Used by CLI summary writers that emit a one-liner like Diagnostics: R-hat(max)=1.02, ESS(min)=420, divergences=3 (0.3%).
- Parameters:
r_hat (dict[str, float]) – Per-parameter R-hat values. NaN values are skipped.
ess_bulk (dict[str, float]) – Per-parameter bulk ESS values. NaN values are skipped.
divergences (int) – Total divergence count across all chains and shards.
n_samples (int) – Posterior samples per chain.
n_chains (int) – Number of chains per shard.
num_shards (int) – Number of CMC shards (default 1 for non-sharded).
- Returns:
Single-line diagnostics summary suitable for log emission.
- Return type:
- heterodyne.optimization.cmc.diagnostics.get_convergence_recommendations(max_rhat, min_ess, divergences, n_samples, n_chains, num_shards=1)[source]
Generate actionable recommendations for convergence issues.
- heterodyne.optimization.cmc.diagnostics.log_analysis_summary(convergence_status, r_hat, ess_bulk, divergences, n_samples, n_chains, n_shards, shards_succeeded, execution_time)[source]
Log a formatted CMC analysis summary at INFO/ERROR level.
- Return type:
- class heterodyne.optimization.cmc.diagnostics.BimodalResult[source]
Bases: object
Result of a bimodality test for a single parameter’s samples.
- param_name
Name of the parameter tested.
- is_bimodal
True if the 2-component GMM is favoured by BIC.
- bic_unimodal
BIC of the 1-component (unimodal) Gaussian mixture.
- bic_bimodal
BIC of the 2-component Gaussian mixture.
- delta_bic
bic_unimodal - bic_bimodal. Positive values favour the bimodal model.
- means
Tuple of the two component means (None when not bimodal).
- weights
Tuple of the two component mixing weights (None when not bimodal).
- param_name: str
- is_bimodal: bool
- bic_unimodal: float
- bic_bimodal: float
- delta_bic: float
- __init__(param_name, is_bimodal, bic_unimodal, bic_bimodal, delta_bic, means, weights)
- heterodyne.optimization.cmc.diagnostics.detect_bimodal(samples, param_name, bic_threshold=10.0, min_weight=0.0, min_separation=0.0)[source]
Fit 1- and 2-component Gaussian mixtures and compare BIC.
Uses scikit-learn’s GaussianMixture to estimate the Bayesian Information Criterion for unimodal vs bimodal models. A positive delta_bic larger than bic_threshold is treated as evidence for bimodality. Two optional post-conditions can tighten the criterion:
min_weight: the minor mode must have at least this mixture weight (0–0.5). Filters spurious bimodality from very unequal modes.
min_separation: the modes must be separated by at least this many sample standard deviations. Filters bimodality in flat or nearly uniform posteriors.
- Parameters:
samples (ndarray) – 1-D array of posterior draws for the parameter.
param_name (str) – Name used for logging and result labelling.
bic_threshold (float) – Minimum delta_bic (BIC_unimodal − BIC_bimodal) required to declare bimodality. The default 10.0 marks the boundary of "very strong" evidence on the Raftery (1995) BIC scale.
min_weight (float) – Minimum weight of the minor mode for the result to be flagged as bimodal. Default 0.0 (no weight filter).
min_separation (float) – Minimum mode separation in sample standard deviations. Default 0.0 (no separation filter).
- Return type:
BimodalResult
- Returns:
BimodalResult with fitted statistics.
- Raises:
ImportError – If scikit-learn is not installed, with a hint on how to add the optional dependency.
- heterodyne.optimization.cmc.diagnostics.check_shard_bimodality(shard_samples, bic_threshold=10.0, min_weight=0.0, min_separation=0.0)[source]
Detect bimodality for each parameter across all CMC shards.
Runs detect_bimodal() for every (parameter, shard) combination and aggregates results per parameter.
- Parameters:
- Return type:
- Returns:
Mapping from parameter name to a list of BimodalResult, one entry per shard (in shard-index order).
- heterodyne.optimization.cmc.diagnostics.compute_nlsq_comparison_metrics(posterior_samples, nlsq_values)[source]
Compare posterior statistics against NLSQ point estimates.
For each parameter present in both posterior_samples and nlsq_values, computes:
posterior_mean – mean of the flattened posterior draws.
posterior_std – standard deviation of the flattened draws.
nlsq_value – the NLSQ point estimate.
z_score – |nlsq_value - posterior_mean| / posterior_std; NaN when posterior_std == 0.
within_hdi – 1.0 if the NLSQ value falls inside the 95 % HDI, 0.0 otherwise.
- Parameters:
- Return type:
- Returns:
Nested dict result[param_name][metric_name] = value. Only parameters present in both inputs are included.
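The per-parameter metrics can be sketched for a single flattened parameter. The HDI here is the common sorted-window estimator: slide a window spanning floor(prob·n) gaps over the sorted draws and keep the narrowest. The library may compute it differently, and `hdi` and `nlsq_comparison` are hypothetical names.

```python
import math
import statistics


def hdi(draws, prob=0.95):
    """Narrowest window over the sorted draws spanning floor(prob * n) gaps."""
    x = sorted(draws)
    n = len(x)
    k = int(math.floor(prob * n))
    start = min(range(n - k), key=lambda i: x[i + k] - x[i])
    return x[start], x[start + k]


def nlsq_comparison(draws, nlsq_value, prob=0.95):
    mean = statistics.fmean(draws)
    std = statistics.pstdev(draws)
    z = abs(nlsq_value - mean) / std if std > 0 else math.nan
    lo, hi = hdi(draws, prob)
    return {
        "posterior_mean": mean,
        "posterior_std": std,
        "nlsq_value": nlsq_value,
        "z_score": z,
        "within_hdi": 1.0 if lo <= nlsq_value <= hi else 0.0,
    }
```

For a skewed posterior the HDI is shifted toward the mode, so within_hdi can disagree with a naive mean ± 2σ check.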
- heterodyne.optimization.cmc.diagnostics.compute_precision_analysis(posterior_samples)[source]
Compute precision metrics for each parameter’s posterior.
For each parameter, calculates:
mean – posterior mean.
std – posterior standard deviation.
cv – coefficient of variation, std / |mean|; inf when mean == 0.
hdi_width – width of the shortest interval containing 95 % of the posterior draws (highest-density interval).
- class heterodyne.optimization.cmc.diagnostics.ModeCluster[source]
Bases: object
A single mode from bimodal consensus combination.
- mean
Per-parameter consensus mean for this mode.
- std
Per-parameter consensus std.
- weight
Fraction of shards supporting this mode (0-1).
- n_shards
Number of shards in this cluster.
- weight: float
- n_shards: int
- __init__(mean, std, weight, n_shards)
- class heterodyne.optimization.cmc.diagnostics.BimodalConsensusResult[source]
Bases: object
Result of mode-aware consensus combination.
- modes
Mode clusters (typically 2) with per-mode consensus.
- modal_params
Parameter names that triggered bimodal detection.
- co_occurrence
Cross-parameter co-occurrence info.
- modes: list[ModeCluster]
- __init__(modes, modal_params, co_occurrence)
- heterodyne.optimization.cmc.diagnostics.summarize_cross_shard_bimodality(bimodal_detections, n_shards, consensus_means=None, significance_threshold=0.05)[source]
Aggregate per-shard bimodal detections into a cross-shard summary.
Groups detections by parameter, computes mode statistics (mean of lower modes, mean of upper modes), and checks whether the consensus posterior mean falls between the modes (density trough).
- Parameters:
bimodal_detections (dict[str, list[BimodalResult]]) – Mapping from parameter name to a list of BimodalResult (one per shard), as returned by check_shard_bimodality().
n_shards (int) – Total number of shards.
consensus_means (dict[str, float] | None) – Consensus posterior means for each parameter. Used to check whether the consensus falls in the density trough between modes.
significance_threshold (float) – Minimum separation significance (separation / pooled_std) for a bimodal split to be reported.
- Returns:
"per_param": {param_name -> {fraction_bimodal, lower_mode_mean, upper_mode_mean, separation, significance, consensus_in_trough}}
"n_detections": Total bimodal detections across all parameters.
"n_shards": Number of shards.
- Return type:
- heterodyne.optimization.cmc.diagnostics.cluster_shard_modes(bimodal_detections, shard_samples, param_bounds=None)[source]
Jointly cluster shards into two mode populations.
Uses the parameters that show bimodal behaviour to build a per-shard feature vector, then runs a simple 2-means clustering (no sklearn dependency) to partition shards.
- Parameters:
bimodal_detections (dict[str, list[BimodalResult]]) – Mapping from parameter name to a list of BimodalResult as returned by check_shard_bimodality().
shard_samples (dict[int, dict[str, ndarray]]) – Per-shard samples mapping {shard_idx: {param_name: samples_array}}.
param_bounds (dict[str, tuple[float, float]] | None) – Optional per-parameter (lo, hi) bounds for normalization. If None, the global range across shards is used.
- Return type:
- Returns:
(cluster_0_indices, cluster_1_indices), where cluster 0 is the “lower” cluster (centroid with the lower mean across features).
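A dependency-free 2-means split of the kind described can be sketched as follows. The names `two_means` and `normalise` are hypothetical. The deterministic initialisation, which starts the centroids at the shards with the lowest and highest mean feature, is an assumption of this sketch. The features are assumed to be already normalised, for example with the bounds-based rescaling shown.

```python
def normalise(value: float, lo: float, hi: float) -> float:
    """Map value into [0, 1] using (lo, hi) bounds; degenerate range maps to 0."""
    return (value - lo) / (hi - lo) if hi > lo else 0.0


def two_means(features: dict[int, list[float]], max_iter: int = 100):
    """Split shards into (lower, upper) clusters with plain 2-means."""
    ids = list(features)
    if len(ids) < 2:
        raise ValueError("need at least two shards")

    def sqdist(a, b):
        return sum((p - q) ** 2 for p, q in zip(a, b))

    def avg(v):
        return sum(v) / len(v)

    # Deterministic init: shards with the lowest and highest mean feature.
    ordered = sorted(ids, key=lambda i: avg(features[i]))
    centroids = [list(features[ordered[0]]), list(features[ordered[-1]])]
    labels = None
    for _ in range(max_iter):
        new_labels = {
            i: min((0, 1), key=lambda c: sqdist(features[i], centroids[c])) for i in ids
        }
        if new_labels == labels:
            break  # assignments stable: converged
        labels = new_labels
        for c in (0, 1):
            members = [features[i] for i in ids if labels[i] == c]
            if members:  # keep the old centroid if a cluster empties
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    lower = 0 if avg(centroids[0]) <= avg(centroids[1]) else 1
    cluster_0 = [i for i in ids if labels[i] == lower]
    cluster_1 = [i for i in ids if labels[i] != lower]
    return cluster_0, cluster_1
```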
- heterodyne.optimization.cmc.diagnostics.log_precision_analysis(analysis, log_fn=None, tolerance_pct=20.0)[source]
Format and emit the CMC vs NLSQ precision analysis report.
Homodyne CMC parity helper. Consumes the output of compute_precision_analysis() and produces a fixed-width report summarising z-scores, percent differences, uncertainty ratios, and posterior contraction ratios per parameter. Flags severe disagreements (z > 3, |diff| > tolerance_pct, ratio < 0.5x).
- Parameters:
analysis (dict[str, dict[str, float]]) – Per-parameter precision metrics from compute_precision_analysis(). Each value is a dict with keys including cmc_mean, cmc_std, nlsq_value, z_score, relative_diff, uncertainty_ratio, posterior_contraction.
log_fn (Callable[[str], None] | None) – Logger function to emit the report through. None routes through this module’s logger at INFO level.
tolerance_pct (float) – Percent-difference threshold for the [WARN] marker. Default 20.0.
- Return type:
- Returns:
The formatted report as a multi-line string.
Reparameterization¶
Reference-time reparameterization for heterodyne CMC.
Breaks banana-shaped posteriors for correlated power-law pairs (D0/alpha) by sampling at a reference time t_ref where the product D(t_ref) = D0 * t_ref^alpha is well-constrained by data.
Adapted from homodyne/optimization/cmc/reparameterization.py for heterodyne’s 3 power-law pairs – D0_ref/alpha_ref, D0_sample/alpha_sample, and v0/beta.
- class heterodyne.optimization.cmc.reparameterization.ReparamConfig[source]
Bases: object
Configuration for reference-time reparameterization.
- enable_d_ref
Reparameterize D0_ref/alpha_ref pair.
- enable_d_sample
Reparameterize D0_sample/alpha_sample pair.
- enable_v_ref
Reparameterize v0/beta pair.
- t_ref
Reference time (geometric mean of dt and t_max).
- enable_d_ref: bool = True
- enable_d_sample: bool = True
- enable_v_ref: bool = True
- t_ref: float = 1.0
- is_reparameterized(name)[source]
Check if a parameter participates in reparameterization.
- Return type:
- get_reparam_name(prefactor)[source]
Get the reparameterized log-space name for a prefactor.
- Return type:
- __init__(enable_d_ref=True, enable_d_sample=True, enable_v_ref=True, t_ref=1.0)
- heterodyne.optimization.cmc.reparameterization.compute_t_ref(dt, t_max, fallback_value=None)[source]
Compute reference time as geometric mean of dt and t_max.
t_ref = sqrt(dt * t_max)
This places t_ref in the middle of the logarithmic time range, where the correlation function is most sensitive to the transport parameters.
- heterodyne.optimization.cmc.reparameterization.transform_nlsq_to_reparam_space(nlsq_values, nlsq_uncertainties, t_ref, config=None)[source]
Transform NLSQ point estimates to reparameterized space.
- For each enabled power-law pair (A0, alpha):
log_A_at_tref = log(A0) + alpha * log(t_ref)
- Delta-method propagation for uncertainty (first order, neglecting the A0–alpha covariance):
Var(log_A_at_tref) ≈ (sigma_A0/A0)² + (log(t_ref) * sigma_alpha)²
Non-reparameterized parameters pass through unchanged.
- Parameters:
- Return type:
- Returns:
Tuple of (transformed_values, transformed_uncertainties).
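The forward map, its delta-method uncertainty, and the inverse map can be sketched for one power-law pair using the same symbols as the formulas above. The function names are hypothetical, and the example numbers (dt = 1 ms, t_max = 10 s) are illustrative only.

```python
import math


def t_ref_geometric_mean(dt: float, t_max: float) -> float:
    """t_ref = sqrt(dt * t_max): the middle of the logarithmic time range."""
    return math.sqrt(dt * t_max)


def to_reparam(a0, alpha, sigma_a0, sigma_alpha, t_ref):
    """(A0, alpha) -> log A(t_ref), with first-order (delta-method) std."""
    log_t = math.log(t_ref)
    value = math.log(a0) + alpha * log_t
    # The A0/alpha covariance term is ignored, as in the variance formula above.
    std = math.sqrt((sigma_a0 / a0) ** 2 + (log_t * sigma_alpha) ** 2)
    return value, std


def to_physics(log_at_tref, alpha, t_ref):
    """Inverse map: A0 = exp(log_at_tref - alpha * log(t_ref))."""
    return math.exp(log_at_tref - alpha * math.log(t_ref))


t_ref = t_ref_geometric_mean(1e-3, 10.0)  # 0.1 s
```

Sampling log A(t_ref) instead of log A0 decouples the prefactor from alpha: near t_ref the data pin down the product D0·t_ref^alpha directly, which is what removes the banana shape.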
- heterodyne.optimization.cmc.reparameterization.transform_to_sampling_space(params, config)[source]
Transform physics-space parameters to sampling (reparam) space.
Used for initializing MCMC chains from NLSQ results.
- heterodyne.optimization.cmc.reparameterization.reparam_to_physics_jax(log_at_tref, alpha, t_ref)[source]
Back-transform reparameterized values to physics space (JAX).
A0 = exp(log_at_tref - alpha * log(t_ref))
- heterodyne.optimization.cmc.reparameterization.transform_to_physics_space(samples, config)[source]
Transform sampling-space posterior samples to physics space.
- Vectorized over sample dimension. For each enabled pair, computes:
A0 = exp(log_at_tref - alpha * log(t_ref))
Non-reparameterized parameters pass through.
- heterodyne.optimization.cmc.reparameterization.D_OFFSET_RATIO_MIN: float = -0.99
Minimum allowed ratio. Slightly above -1 so the implied D_offset > -D_ref keeps the diffusion rate strictly positive at t_ref while preserving gradient information near the boundary.
- heterodyne.optimization.cmc.reparameterization.d_offset_to_ratio(d_offset, d_ref)[source]
Convert an absolute offset to the ratio representation.
d_offset_ratio = d_offset / d_ref. Returns 0.0 when d_ref is non-positive (degenerate channel).
- Return type:
- heterodyne.optimization.cmc.reparameterization.ratio_to_d_offset(ratio, d_ref)[source]
Reconstruct the absolute offset from the ratio representation.
d_offset = ratio * d_ref. Returns 0.0 when d_ref is non-positive.
- Return type:
- heterodyne.optimization.cmc.reparameterization.heterodyne_offset_ratios_from_physics(params, t_ref)[source]
Compute D_offset_ratio for both reference and sample channels.
Evaluates each channel’s diffusion magnitude at t_ref as D_ref(t_ref) = D0 * t_ref**alpha and returns the ratios D_offset_*_ratio = D_offset_* / D_ref(t_ref). Channels whose D_ref(t_ref) is non-positive yield a 0.0 ratio so callers can fall back to direct sampling for that channel.
- heterodyne.optimization.cmc.reparameterization.heterodyne_physics_offsets_from_ratios(ratios, physics, t_ref)[source]
Inverse of heterodyne_offset_ratios_from_physics().
Given D_offset_*_ratio values and the current physics-space (D0_*, alpha_*) parameters, returns the absolute D_offset_* values consistent with D_ref(t_ref).
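The forward and inverse offset-ratio maps for the two channels can be sketched with plain dicts. The function names are hypothetical. The key layout (D0_ref, alpha_ref, D_offset_ref, and so on) follows the parameter names used in this section, with D_offset_*_ratio expanded to D_offset_ref_ratio and D_offset_sample_ratio.

```python
def diffusion_at_tref(d0: float, alpha: float, t_ref: float) -> float:
    """D_ref(t_ref) = D0 * t_ref**alpha."""
    return d0 * t_ref**alpha


def offset_ratios(params: dict[str, float], t_ref: float) -> dict[str, float]:
    """D_offset_*_ratio = D_offset_* / D_ref(t_ref) for the ref and sample channels."""
    ratios = {}
    for ch in ("ref", "sample"):
        d_ref = diffusion_at_tref(params[f"D0_{ch}"], params[f"alpha_{ch}"], t_ref)
        # Non-positive D_ref(t_ref): degenerate channel, ratio falls back to 0.0.
        ratios[f"D_offset_{ch}_ratio"] = params[f"D_offset_{ch}"] / d_ref if d_ref > 0 else 0.0
    return ratios


def offsets_from_ratios(ratios: dict[str, float], params: dict[str, float], t_ref: float):
    """Inverse map: D_offset_* = ratio * D_ref(t_ref), or 0.0 for a degenerate channel."""
    offsets = {}
    for ch in ("ref", "sample"):
        d_ref = diffusion_at_tref(params[f"D0_{ch}"], params[f"alpha_{ch}"], t_ref)
        offsets[f"D_offset_{ch}"] = ratios[f"D_offset_{ch}_ratio"] * d_ref if d_ref > 0 else 0.0
    return offsets
```

A degenerate channel does not round-trip: its offset comes back as 0.0, which is why callers fall back to direct sampling for it.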
Scaling¶
Parameter scaling utilities for CMC warm-start initialization and contrast/offset estimation.
Smooth bounded parameter scaling for heterodyne CMC.
Replaces jnp.clip() (zero gradient at bounds) with tanh-based smooth bounding that is differentiable everywhere, allowing NUTS to adapt its mass matrix near parameter boundaries.
Adapted from homodyne/optimization/cmc/scaling.py.
- class heterodyne.optimization.cmc.scaling.ParameterScaling[source]
Bases: object
Scaling specification for a single parameter.
Defines the mapping between z-space (standard normal for MCMC) and original physics space with smooth bounding.
- name
Parameter name.
- center
NLSQ best-fit value (center of prior).
- scale
Prior width (NLSQ uncertainty × width_factor).
- low
Lower bound in physics space.
- high
Upper bound in physics space.
- name: str
- center: float
- scale: float
- low: float
- high: float
- to_normalized(value)[source]
Transform from physics space to z-space (normalized).
z = (value - center) / scale
- to_original(z_value)[source]
Transform from z-space to bounded original (physics) space.
raw = center + scale * z
result = smooth_bound(raw, low, high)
- __init__(name, center, scale, low, high)
- heterodyne.optimization.cmc.scaling.smooth_bound(raw, low, high)[source]
Smooth bounding using tanh transform.
- Maps (-inf, +inf) → (low, high) via:
mid + half * tanh((raw - mid) / half), where mid = (low + high) / 2 and half = (high - low) / 2.
This is differentiable everywhere, unlike jnp.clip() which has zero gradient at bounds and kills NUTS adaptation.
- heterodyne.optimization.cmc.scaling.smooth_bound_inverse(value, low, high)[source]
Inverse of smooth_bound for initialization.
- Recovers the raw (unbounded) value from a bounded value:
raw = mid + half * arctanh((value - mid) / half)
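The tanh bound and its arctanh inverse can be sketched directly from the two formulas above. These are scalar illustrations using Python's `math` module, not the JAX versions the module presumably uses for sampling.

```python
import math


def smooth_bound(raw: float, low: float, high: float) -> float:
    """Map (-inf, +inf) onto (low, high) with slope 1 at the midpoint."""
    mid = 0.5 * (low + high)
    half = 0.5 * (high - low)
    return mid + half * math.tanh((raw - mid) / half)


def smooth_bound_inverse(value: float, low: float, high: float) -> float:
    """Recover raw from a bounded value; value must lie strictly inside (low, high)."""
    mid = 0.5 * (low + high)
    half = 0.5 * (high - low)
    # atanh is finite only on the open interval (-1, 1).
    return mid + half * math.atanh((value - mid) / half)
```

The bound is not the identity inside the interval. For example, raw = 0.9 on (0, 1) maps to roughly 0.83, which is why warm-start values must go through the inverse rather than being used as raw values directly. In floating point, very large |raw| saturates to exactly low or high.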
- heterodyne.optimization.cmc.scaling.compute_scaling_factors(space, nlsq_values=None, nlsq_uncertainties=None, width_factor=2.0)[source]
Build ParameterScaling for each varying parameter.
Uses NLSQ values as centers and NLSQ uncertainties × width_factor as scale. Falls back to bounds midpoint and range/6 when NLSQ results are unavailable.
- Parameters:
- Return type:
- Returns:
Dict mapping parameter name to ParameterScaling.
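The center/scale choice and its fallback can be sketched for one parameter. `center_and_scale` is a hypothetical per-parameter helper. Treating a missing or non-positive NLSQ uncertainty as "unavailable" is an assumption of the sketch.

```python
from typing import Optional


def center_and_scale(
    low: float,
    high: float,
    nlsq_value: Optional[float] = None,
    nlsq_uncertainty: Optional[float] = None,
    width_factor: float = 2.0,
) -> tuple[float, float]:
    """Hypothetical per-parameter version of the center/scale choice."""
    if nlsq_value is not None and nlsq_uncertainty is not None and nlsq_uncertainty > 0:
        # Prior centered on the NLSQ fit, widened by width_factor.
        return nlsq_value, nlsq_uncertainty * width_factor
    # Fallback: bounds midpoint, with the full range spanning +/- 3 scale units.
    return 0.5 * (low + high), (high - low) / 6.0


def to_normalized(value: float, center: float, scale: float) -> float:
    """z = (value - center) / scale."""
    return (value - center) / scale
```

The range/6 fallback places each bound three scale units from the center, so a standard-normal z rarely reaches the bounds.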
- heterodyne.optimization.cmc.scaling.log_scaling_factors(scalings)[source]
Log all scaling factors for debugging.
Emits a header at INFO level, then per-parameter details at DEBUG.
- heterodyne.optimization.cmc.scaling.transform_initial_values_to_z(initial_values, scalings)[source]
Transform initial values from physics space to z-space.
Only transforms parameters present in both initial_values and scalings.
- heterodyne.optimization.cmc.scaling.transform_samples_from_z(samples, scalings)[source]
Transform MCMC samples from z-space back to physics space.
Input keys must end with "_z"; the suffix is stripped to recover the original parameter name. Only parameters with a matching entry in scalings are transformed.
Data Preparation
Data preparation utilities for CMC analysis.
Handles validation, JAX conversion, and sharding of correlation data for large-dataset MCMC workflows.
- class heterodyne.optimization.cmc.data_prep.ShardingStrategy[source]
Bases: Enum
Strategy for partitioning prepared data into shards.
- RANDOM
Randomly assign data points to shards with a fixed seed.
- CONTIGUOUS
Split along the time axis into contiguous blocks.
- STRATIFIED
Stratified by time range so each shard covers all epochs.
- ANGLE_BALANCED
Each shard receives proportional representation from every phi angle, preventing heterogeneous sub-posteriors.
- RANDOM = 'random'
- CONTIGUOUS = 'contiguous'
- STRATIFIED = 'stratified'
- ANGLE_BALANCED = 'angle_balanced'
- class heterodyne.optimization.cmc.data_prep.PreparedData[source]
Bases: object
Validated and structured data container for CMC/NUTS sampling.
All arrays are kept as NumPy for compatibility with both JAX and SciPy backends. The caller converts to JAX inside the sampler.
- c2_data
Flattened observed correlation values, shape (n_total,).
- weights
Per-element likelihood weights, shape (n_total,), or None when uniform weighting is used.
- time_array
Unique time values used to build the time grid, shape (n_times,).
- phi_angles
Per-element phi angles (radians or degrees), shape (n_total,).
- q
Wavevector magnitude in Å⁻¹.
- dt
Frame time step in seconds.
- metadata
Arbitrary key/value pairs (configuration, provenance, …).
- n_angles
Number of unique phi angles.
- n_times
Length of time_array.
- c2_data: ndarray
- time_array: ndarray
- phi_angles: ndarray
- q: float
- dt: float
- n_angles: int = 0
- n_times: int = 0
- __init__(c2_data, weights, time_array, phi_angles, q, dt, metadata=<factory>, n_angles=0, n_times=0)
- class heterodyne.optimization.cmc.data_prep.PooledCMCData[source]
Bases: object
Pooled multi-phi data container for joint CMC (homodyne parity).
Mirrors homodyne.optimization.cmc.data_prep.PreparedData so heterodyne can run ONE NUTS pass conditioned on all phi angles with shared physics parameters. The pooled layout flattens all angles and the (t1, t2) grid into a single axis of shape (n_total,); phi_indices maps each pooled point to its angle in phi_unique.
- data
Pooled C2 values, shape (n_total,).
- t1, t2
Pooled time coordinates, shape (n_total,).
- phi
Pooled phi angles per point, shape (n_total,).
- phi_unique
Sorted unique phi angles, shape (n_phi,).
- phi_indices
Per-point index into phi_unique, shape (n_total,).
- n_total
Number of pooled data points (after diagonal filtering).
- n_phi
Cardinality of phi_unique.
- noise_scale
MAD-based noise estimate, used to centre the sampled sigma prior.
- data: ndarray
- t1: ndarray
- t2: ndarray
- phi: ndarray
- phi_unique: ndarray
- phi_indices: ndarray
- n_total: int
- n_phi: int
- noise_scale: float
- __init__(data, t1, t2, phi, phi_unique, phi_indices, n_total, n_phi, noise_scale)
- heterodyne.optimization.cmc.data_prep.validate_pooled_data(data, t1, t2, phi)[source]
Validate that the pooled arrays have matching lengths, finite values, and sane shapes.
- Return type:
- heterodyne.optimization.cmc.data_prep.extract_phi_info(phi)[source]
Return (phi_unique, phi_indices) with tolerance-aware matching.
Mirrors homodyne.optimization.cmc.data_prep.extract_phi_info exactly. For n_phi <= 256 it uses argmin(|phi - phi_unique|, axis=1), so float rounding doesn’t misassign points to neighbouring angles. For larger phi counts it falls back to searchsorted with a left-neighbour check.
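The small-n_phi branch (nearest-angle assignment via argmin) can be sketched in plain Python. The documentation does not state how phi_unique is deduplicated, so rounding to a fixed number of decimals here is a hypothetical choice. Only the argmin matching step follows the description above.

```python
from typing import List, Sequence, Tuple


def extract_phi_info(phi: Sequence[float], decimals: int = 6) -> Tuple[List[float], List[int]]:
    """Return (phi_unique, phi_indices) using nearest-angle matching.

    `decimals` is a hypothetical deduplication tolerance for this sketch.
    """
    phi_unique = sorted({round(p, decimals) for p in phi})
    # argmin(|phi - phi_unique|) per point: float noise cannot push a point
    # onto a neighbouring angle the way an exact-equality lookup could.
    phi_indices = [
        min(range(len(phi_unique)), key=lambda j: abs(p - phi_unique[j]))
        for p in phi
    ]
    return phi_unique, phi_indices
```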
- heterodyne.optimization.cmc.data_prep.prepare_mcmc_data(data, t1, t2, phi, filter_diagonal=True)[source]
Validate + filter pooled XPCS data for joint multi-phi CMC.
Mirrors homodyne.optimization.cmc.data_prep.prepare_mcmc_data. With filter_diagonal=True (default), it removes t1 == t2 rows using an epsilon-based comparison sized to the smallest positive dt in the arrays. The diagonal is loaded and plotted but excluded from likelihood fitting, following the same boundary contract as the t=0 row/col.
- Parameters:
data, t1, t2, phi (ndarray) – Pooled (n_total,) arrays. Each entry is one (C2 value, t1, t2, phi) observation.
filter_diagonal (bool) – When True, drop entries where |t1 - t2| <= ε.
- Returns:
Filtered + indexed container ready for the NumPyro model.
- Return type:
PooledCMCData
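The diagonal filter can be sketched on plain lists. The documentation only says ε is sized to the smallest positive dt. Taking ε as half of the smallest positive spacing between distinct time values is a hypothetical choice made here.

```python
from typing import List, Sequence, Tuple


def filter_diagonal(
    data: Sequence[float], t1: Sequence[float], t2: Sequence[float], phi: Sequence[float]
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """Drop pooled points with |t1 - t2| <= eps (sketch)."""
    times = sorted(set(t1) | set(t2))
    steps = [b - a for a, b in zip(times, times[1:]) if b - a > 0]
    # Hypothetical choice: eps is half the smallest positive time step.
    eps = 0.5 * min(steps) if steps else 0.0
    keep = [i for i in range(len(data)) if abs(t1[i] - t2[i]) > eps]
    return (
        [data[i] for i in keep],
        [t1[i] for i in keep],
        [t2[i] for i in keep],
        [phi[i] for i in keep],
    )
```

An ε tied to the time spacing, rather than an exact t1 == t2 test, tolerates rounding in stored time values.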
- heterodyne.optimization.cmc.data_prep.shard_pooled_random(prepared, num_shards=None, max_points_per_shard=None, max_shards=100, seed=42)[source]
Shard pooled data into ~equal random subsets (homodyne parity).
Used when there is a single phi angle, or as the fallback for multi-angle data when angle-balanced sharding is not requested. Every data point lands in exactly one shard — no subsampling, no point loss.
Mirrors homodyne.optimization.cmc.data_prep.shard_data_random.
- Parameters:
prepared (PooledCMCData) – Pooled multi-phi data container.
num_shards (int | None) – Explicit shard count. When None, it is derived from max_points_per_shard (ceil division) if that is given, else set to 1.
max_points_per_shard (int | None) – Target points per shard, used to derive num_shards when it is not given.
max_shards (int) – Hard cap on shard count; when exceeded, the shard size grows so all data still fits in max_shards shards.
seed (int) – Seed for the index shuffle (reproducible assignment).
- Return type:
list[PooledCMCData]
- Returns:
List of PooledCMCData shards covering all points.
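The shard-count derivation and random assignment can be sketched on point indices. The ceil division and the max_shards cap follow the parameter descriptions. The following details are assumptions: the extra cap at n_total, which avoids empty shards, and splitting the shuffled order into near-equal contiguous runs.

```python
import math
import random
from typing import List, Optional


def random_shard_indices(
    n_total: int,
    num_shards: Optional[int] = None,
    max_points_per_shard: Optional[int] = None,
    max_shards: int = 100,
    seed: int = 42,
) -> List[List[int]]:
    """Partition range(n_total) into random shards whose sizes differ by at most 1."""
    if num_shards is None:
        # Ceil division when a target size is given, else a single shard.
        num_shards = math.ceil(n_total / max_points_per_shard) if max_points_per_shard else 1
    # Hard cap: fewer, larger shards rather than dropping points (n_total cap is an assumption).
    num_shards = max(1, min(num_shards, max_shards, n_total))
    order = list(range(n_total))
    random.Random(seed).shuffle(order)  # reproducible assignment
    bounds = [k * n_total // num_shards for k in range(num_shards + 1)]
    return [order[bounds[k]:bounds[k + 1]] for k in range(num_shards)]
```

Every index lands in exactly one shard, matching the "no subsampling, no point loss" contract.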
- heterodyne.optimization.cmc.data_prep.shard_pooled_angle_balanced(prepared, num_shards=None, max_points_per_shard=None, max_shards=500, min_angle_coverage=0.8, seed=42)[source]
Shard pooled data with proportional per-angle coverage (homodyne parity).
Preferred strategy for multi-angle datasets (n_phi > 1). Each shard samples proportionally from every phi angle so sub-posteriors stay homogeneous. Pure random sharding can leave shards with uneven angle coverage, producing high cross-shard parameter variance that Consensus MC then combines incorrectly.
Mirrors homodyne.optimization.cmc.data_prep.shard_data_angle_balanced.
- Parameters:
prepared (PooledCMCData) – Pooled multi-phi data container.
num_shards (int | None) – Explicit shard count. When None, it is derived from max_points_per_shard (ceil division) if that is given, else set to max(1, n_phi).
max_points_per_shard (int | None) – Target points per shard, used to derive num_shards when it is not given.
max_shards (int) – Hard cap on shard count.
min_angle_coverage (float) – Fraction of angles each shard should contain; shards below this are logged as a diagnostic (not an error).
seed (int) – Seed for per-angle shuffles (reproducible assignment).
- Return type:
list[PooledCMCData]
- Returns:
List of PooledCMCData shards with balanced angle coverage. Falls back to shard_pooled_random() when n_phi == 1.
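One way to get proportional per-angle coverage is the following. Shuffle each angle's points independently, then deal them round-robin across shards, so every shard receives about 1/num_shards of every angle. This is a sketch of the idea, not the module's exact assignment rule; both function names are hypothetical. The coverage helper mirrors the min_angle_coverage diagnostic.

```python
import random
from typing import Dict, List, Sequence


def angle_balanced_indices(
    phi_indices: Sequence[int], num_shards: int, seed: int = 42
) -> List[List[int]]:
    """Split pooled points so each shard samples every angle proportionally."""
    rng = random.Random(seed)
    by_angle: Dict[int, List[int]] = {}
    for point, angle in enumerate(phi_indices):
        by_angle.setdefault(angle, []).append(point)
    shards: List[List[int]] = [[] for _ in range(num_shards)]
    offset = 0
    for angle in sorted(by_angle):
        points = by_angle[angle]
        rng.shuffle(points)  # per-angle shuffle, reproducible via seed
        for j, point in enumerate(points):
            # Continue the rotation across angles so remainders spread evenly.
            shards[(offset + j) % num_shards].append(point)
        offset += len(points)
    return shards


def angle_coverage(shard: List[int], phi_indices: Sequence[int], n_phi: int) -> float:
    """Fraction of angles present in a shard (compare against min_angle_coverage)."""
    return len({phi_indices[i] for i in shard}) / n_phi
```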
- heterodyne.optimization.cmc.data_prep.prepare_cmc_data(c2_data, sigma=None, weights=None)[source]
Validate and convert correlation data to JAX arrays.
Performs shape, dtype, NaN, and monotonicity checks on the input correlation matrix before transferring to JAX device memory.
- Parameters:
c2_data (ndarray | Array) – Observed two-time correlation matrix. Must be 2-D and square (or 1-D for single-time slices).
sigma (ndarray | float | None) – Measurement uncertainty. A scalar broadcasts to all elements; an array must match the c2_data shape. If None, returns None and the caller is responsible for estimation.
weights (ndarray | Array | None) – Optional per-element weights for likelihood weighting. Must match the c2_data shape if provided.
- Return type:
- Returns:
Tuple of (c2_jax, sigma_jax, weights_jax) ready for the NumPyro model. sigma_jax is the scalar or array sigma (or a None passthrough when the input is None). weights_jax is None when no weights are provided.
- Raises:
ValueError – If data contains NaN, has mismatched shapes, or violates expected structure.
- heterodyne.optimization.cmc.data_prep.create_shard_grid(n_times, n_shards)[source]
Create time-index partitions for sharding a correlation matrix.
Divides the time axis into approximately equal chunks so that each shard can be processed independently (e.g., for consensus Monte Carlo on very large two-time matrices).
- Parameters:
- Return type:
- Returns:
List of (start, stop) index pairs (half-open intervals) that partition range(n_times) into n_shards contiguous chunks.
- Raises:
ValueError – If n_shards < 1 or n_shards > n_times.
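A minimal sketch of the partition follows. The documentation only promises approximately equal chunks, so spreading the remainder of n_times / n_shards over the first chunks is an assumption.

```python
from typing import List, Tuple


def create_shard_grid(n_times: int, n_shards: int) -> List[Tuple[int, int]]:
    """Half-open (start, stop) pairs that partition range(n_times)."""
    if n_shards < 1 or n_shards > n_times:
        raise ValueError(f"n_shards must be in [1, {n_times}], got {n_shards}")
    base, extra = divmod(n_times, n_shards)
    grid, start = [], 0
    for k in range(n_shards):
        # The first `extra` chunks absorb one leftover index each.
        stop = start + base + (1 if k < extra else 0)
        grid.append((start, stop))
        start = stop
    return grid
```

Each pair can then select a diagonal block as c2[start:stop, start:stop], which is the block shape shard_correlation_data returns.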
- heterodyne.optimization.cmc.data_prep.shard_correlation_data(c2_data, shard_grid)[source]
Split a two-time correlation matrix into shards along both axes.
Each shard is a sub-block of the full correlation matrix defined by the row and column index ranges in shard_grid. Only diagonal blocks (same row and column shard) are returned, as off-diagonal blocks carry cross-shard correlations that are handled separately in the consensus step.
- Parameters:
- Return type:
- Returns:
List of JAX arrays, one per shard, each of shape (stop - start, stop - start).
- Raises:
ValueError – If c2_data is not 2-D or if shard indices are out of bounds.
- heterodyne.optimization.cmc.data_prep.merge_shard_results(shard_results)[source]
Combine per-shard posterior samples via simple concatenation.
For consensus Monte Carlo, this implements the naive pooling strategy where samples from each shard’s sub-posterior are concatenated. The caller may apply further weighting or density product corrections.
All shards must contain the same set of parameter names.
- Parameters:
shard_results (list[dict[str, ndarray]]) – List of sample dictionaries, one per shard. Each dict maps parameter names to 1-D arrays of posterior draws.
- Return type:
- Returns:
Merged dictionary with concatenated samples for each parameter.
- Raises:
ValueError – If shard results are empty or have mismatched keys.
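The naive pooling step amounts to a key check followed by per-parameter concatenation. This sketch uses plain lists instead of NumPy arrays so it runs without dependencies. As noted above, plain concatenation is not a density-product combination; any weighting is left to the caller.

```python
from typing import Dict, List


def merge_shard_results(shard_results: List[Dict[str, List[float]]]) -> Dict[str, List[float]]:
    """Concatenate per-shard draws for each parameter (naive pooling)."""
    if not shard_results:
        raise ValueError("shard_results is empty")
    keys = set(shard_results[0])
    for i, shard in enumerate(shard_results[1:], start=1):
        if set(shard) != keys:
            raise ValueError(f"shard {i} has mismatched parameter names")
    # Concatenate in shard order for every parameter.
    return {k: [x for shard in shard_results for x in shard[k]] for k in sorted(keys)}
```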
- heterodyne.optimization.cmc.data_prep.prepare_data(raw_data, config=None)[source]
Validate, normalise, and package raw XPCS data for CMC sampling.
This is the main entry point for converting a raw data dictionary (as produced by the XPCS loader) into a PreparedData instance suitable for NUTS or CMC workflows.
- Parameters:
raw_data – Dictionary with at least the following keys:
"c2_data" – array-like, shape (n_angles, n_t, n_t) or (n_t, n_t).
"phi_angles" – 1-D array of azimuthal angles (degrees or radians), length n_angles.
"time_array" – 1-D monotonically increasing time axis.
"q" – scalar wavevector magnitude (Å⁻¹).
"dt" – scalar frame time step (seconds).
Optional keys:
"weights" – array matching c2_data, per-element likelihood weights.
"mask" – boolean array matching c2_data; True where data should be excluded.
config (dict[str, Any] | None) – Optional configuration dictionary. Recognised keys:
"normalize_weights" (bool, default True) – rescale weights so their mean equals 1.
"require_positive_diagonal" (bool, default True) – raise if any diagonal element <= 0.
- Return type:
PreparedData
- Returns:
PreparedData ready for create_shards().
- Raises:
ValueError – If required keys are missing, arrays have unexpected shapes, or data contains NaN / non-finite values.
KeyError – If a required key is absent from raw_data.
- heterodyne.optimization.cmc.data_prep.create_shards(prepared_data, n_shards, strategy=ShardingStrategy.ANGLE_BALANCED, *, seed=42)[source]
Split a PreparedData instance into n_shards sub-datasets.
Each shard is itself a PreparedData with the same q, dt, and time_array as the parent, but contains only a subset of the pooled data points.
- Parameters:
- Return type:
list[PreparedData]
- Returns:
List of n_shards PreparedData instances.
- Raises:
ValueError – If n_shards < 1 or the strategy is unsupported.
- heterodyne.optimization.cmc.data_prep.validate_shard_data(shard)[source]
Validate a single shard for common data quality issues.
Checks performed:
No NaN or non-finite values in c2_data.
Shape consistency between c2_data, phi_angles, and weights (when present).
At least one data point.
Positive values in the subset of elements corresponding to the diagonal of the original two-time matrix.
- Parameters:
shard (PreparedData) – PreparedData shard to validate.
- Raises:
ValueError – On any detected integrity issue.
- Return type:
- heterodyne.optimization.cmc.data_prep.estimate_shard_memory(shard)[source]
Estimate the device memory footprint of a shard in bytes.
Counts all NumPy arrays stored in the shard, using their actual nbytes attribute. This is a lower bound because JAX may add internal buffers during JIT compilation, but it is accurate enough for pre-flight capacity checks.
- Parameters:
shard (PreparedData) – PreparedData shard.
- Return type:
- Returns:
Estimated memory in bytes.
CMC I/O
Full shard-level and aggregate I/O pipeline for CMC results. Supports
NPZ sample arrays, ArviZ InferenceData, JSON parameter/diagnostics
files, and fitted-data arrays.
Shard I/O for CMC (Consensus Monte Carlo) results.
Provides functions to persist and reload posterior samples as .npz
archives and ArviZ InferenceData objects as NetCDF files.
- heterodyne.optimization.cmc.io.save_shard_results(results, output_dir, shard_id)[source]
Save posterior samples for a single shard as a .npz archive.
The file is written to <output_dir>/shard_<shard_id>.npz.
- heterodyne.optimization.cmc.io.load_shard_results(output_dir, shard_id)[source]
Load posterior samples for a single shard.
- heterodyne.optimization.cmc.io.list_shards(output_dir)[source]
Discover saved shard IDs in output_dir.
Scans for files matching shard_<N>.npz and returns a sorted list of the integer shard IDs.
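Shard discovery can be sketched with pathlib and a regular expression that accepts only shard_<N>.npz names. This approximates the scan described above rather than reproducing the module's code. Parsing the IDs as integers makes the sort numeric, so shard_2 comes before shard_10.

```python
import re
from pathlib import Path
from typing import List, Union

_SHARD_RE = re.compile(r"^shard_(\d+)\.npz$")


def list_shards(output_dir: Union[str, Path]) -> List[int]:
    """Return sorted integer IDs of files named shard_<N>.npz in output_dir."""
    ids = []
    for path in Path(output_dir).iterdir():
        match = _SHARD_RE.match(path.name)
        if match and path.is_file():
            ids.append(int(match.group(1)))
    return sorted(ids)
```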
- heterodyne.optimization.cmc.io.save_inference_data(idata, path)[source]
Save an ArviZ InferenceData object as a NetCDF file.
- Parameters:
- Return type:
- Returns:
Path to the saved file.
- Raises:
ImportError – If arviz is not installed.
- heterodyne.optimization.cmc.io.load_inference_data(path)[source]
Load an ArviZ InferenceData object from a NetCDF file.
- Parameters:
path (str | Path) – Path to a NetCDF file previously written by save_inference_data().
- Return type:
- Returns:
ArviZ InferenceData object.
- Raises:
FileNotFoundError – If the file does not exist.
ImportError – If arviz is not installed.
- heterodyne.optimization.cmc.io.save_samples_npz(result, output_path)[source]
Save posterior samples as a compressed NPZ archive (ArviZ-compatible).
Shape stored: posterior_samples is (n_chains, n_samples, n_params).
- Return type:
- heterodyne.optimization.cmc.io.load_samples_npz(input_path)[source]
Load samples NPZ and return a plain Python dict.
- Raises:
FileNotFoundError – If the file does not exist.
ValueError – If the suffix is not .npz or the path is not a regular file.
- Return type:
- heterodyne.optimization.cmc.io.samples_to_arviz(samples_data)[source]
Convert loaded samples dict to ArviZ InferenceData.
- heterodyne.optimization.cmc.io.save_fitted_data_npz(result, c2_exp, c2_fitted, c2_fitted_std, t1, t2, phi_angles, q, output_path)[source]
Save fitted C2 data in NLSQ-compatible NPZ format.
- Return type:
- heterodyne.optimization.cmc.io.save_parameters_json(result, output_path)[source]
Save posterior statistics per parameter to JSON.
NaN -> null, Inf -> "Infinity" / "-Infinity".
- Return type:
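The non-finite mapping can be sketched as a recursive sanitizer applied before json.dumps. Writing infinities as the JSON strings "Infinity" / "-Infinity" is an assumption based on the quoted values above. Passing allow_nan=False makes json.dumps raise if anything non-finite slips through, so the output stays standard JSON.

```python
import json
import math
from typing import Any


def sanitize(value: Any) -> Any:
    """Replace NaN with None and +/-inf with strings, recursing into containers."""
    if isinstance(value, float):
        if math.isnan(value):
            return None
        if math.isinf(value):
            return "Infinity" if value > 0 else "-Infinity"
        return value
    if isinstance(value, dict):
        return {k: sanitize(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [sanitize(v) for v in value]
    return value


def dumps_parameters(stats: dict) -> str:
    """Serialise per-parameter statistics as strict JSON."""
    return json.dumps(sanitize(stats), indent=2, allow_nan=False)
```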
- heterodyne.optimization.cmc.io.save_diagnostics_json(result, output_path, warnings=None)[source]
Save convergence diagnostics to JSON.
- Return type:
CMC Plotting
ArviZ-backed diagnostic plots for CMC posterior samples.
ArviZ-based MCMC diagnostic plots for CMC results.
All functions accept ArviZ InferenceData objects and return
matplotlib figures. arviz is imported with a try/except guard
so that the rest of the package does not hard-depend on it.
- heterodyne.optimization.cmc.plotting.plot_trace_summary(idata, var_names=None, figsize=None)[source]
ArviZ trace plot with marginal posteriors.
- heterodyne.optimization.cmc.plotting.plot_pair_plot(idata, var_names=None, divergences=True)[source]
ArviZ pair plot with optional divergence markers.
- heterodyne.optimization.cmc.plotting.plot_posterior_predictive(idata, c2_data, times, ax=None)[source]
Overlay posterior-predictive draws on experimental data.
If idata contains a posterior_predictive group with a variable named "c2_pred", the 5th/95th percentile envelope is drawn. Otherwise, a message is displayed instead.
- heterodyne.optimization.cmc.plotting.plot_diagnostics_summary(idata)[source]
Combined R-hat, ESS, and BFMI diagnostic panels.
Creates a three-panel figure:
1. R-hat per parameter (bar chart).
2. Bulk ESS per parameter (bar chart).
3. BFMI per chain (bar chart).
- Parameters:
idata (object) – ArviZ InferenceData object.
- Return type:
Figure
- Returns:
Matplotlib Figure.
- heterodyne.optimization.cmc.plotting.plot_forest(idata, output_dir, var_names=None, figsize=DEFAULT_FIGSIZE, dpi=DEFAULT_DPI)[source]
Save an ArviZ forest plot to output_dir / forest_plot.png.
Shows posterior intervals (94% HDI by default) for each parameter. Homodyne-parity helper.
- Return type:
- heterodyne.optimization.cmc.plotting.plot_energy(idata, output_dir, figsize=(10, 6), dpi=DEFAULT_DPI)[source]
Save an ArviZ energy plot to output_dir / energy_plot.png.
Falls back gracefully when sample_stats lacks an energy field by writing a placeholder image with an explanatory message. Homodyne-parity helper that handles NumPyro’s potential_energy naming.
- Return type:
- heterodyne.optimization.cmc.plotting.plot_autocorr(idata, output_dir, var_names=None, figsize=DEFAULT_FIGSIZE, dpi=DEFAULT_DPI)[source]
Save an ArviZ autocorrelation plot. Homodyne-parity helper.
- Return type:
- heterodyne.optimization.cmc.plotting.plot_rank(idata, output_dir, var_names=None, figsize=DEFAULT_FIGSIZE, dpi=DEFAULT_DPI)[source]
Save an ArviZ rank plot. Helps detect chain-mixing issues.
- Return type:
- heterodyne.optimization.cmc.plotting.plot_ess(idata, output_dir, var_names=None, figsize=(10, 6), dpi=DEFAULT_DPI)[source]
Save an ArviZ ESS-evolution plot.
- Return type:
- heterodyne.optimization.cmc.plotting.generate_diagnostic_plots(idata, output_dir, var_names=None, dpi=DEFAULT_DPI)[source]
Generate the full homodyne-parity diagnostic plot suite.
Writes forest_plot.png, energy_plot.png, autocorr_plot.png, rank_plot.png, and ess_plot.png to output_dir. Individual failures are isolated: one broken plot does not abort the others.
- Parameters:
- Return type:
- Returns:
Mapping plot_kind -> output_path for each plot that succeeded.
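The per-plot failure isolation follows a common pattern. Each plotting callable runs inside its own try/except, failures are logged, and only successful outputs are recorded. The run_isolated name, the plotters mapping, and the callable signature are hypothetical. The real function calls the ArviZ-backed helpers listed in this section.

```python
import logging
from pathlib import Path
from typing import Callable, Dict

logger = logging.getLogger(__name__)

PlotFn = Callable[[Path], Path]  # hypothetical: takes output_dir, returns the written file


def run_isolated(plotters: Dict[str, PlotFn], output_dir: Path) -> Dict[str, Path]:
    """Run each plotter; one failure does not stop the rest."""
    results: Dict[str, Path] = {}
    for kind, plot in plotters.items():
        try:
            results[kind] = plot(output_dir)
        except Exception:  # isolate any plotting error and keep going
            logger.exception("Diagnostic plot %r failed", kind)
    return results
```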
Backends
Execution backends for running NUTS chains across CPU cores, HPC
cluster nodes, and multi-device pjit configurations. All backends
share the MCMCBackend protocol and are selected via select_backend()
based on the backend_name setting ("auto" applies a hardware-based
heuristic).
Backend Selection
Abstract base and factory for MCMC execution backends.
Includes Consensus Monte Carlo utilities for combining posteriors from independent MCMC shards via inverse-variance (precision) weighting.
- class heterodyne.optimization.cmc.backends.base.BackendCapabilities[source]
Bases: object
Static description of what an MCMC backend can do.
Used by the backend selection logic and resource estimation code to choose the best available backend at runtime without instantiating every candidate.
- supports_sharding
True if the backend can distribute data shards across workers or devices.
- supports_parallel_chains
True if chains can run concurrently (e.g. via pmap or a worker pool).
- max_parallel_shards
Maximum number of shards the backend can handle simultaneously. 1 means strictly sequential.
- supports_sharding: bool = False
- supports_parallel_chains: bool = True
- max_parallel_shards: int = 1
- __init__(supports_sharding=False, supports_parallel_chains=True, max_parallel_shards=1)
- class heterodyne.optimization.cmc.backends.base.MCMCBackend[source]
Bases: Protocol
Protocol for MCMC execution backends.
Each backend wraps NumPyro’s MCMC machinery with a CPU execution strategy (sequential single-device or parallel multi-device).
- run(model, config, rng_key, init_params=None)[source]
Run MCMC sampling and return posterior samples.
- Parameters:
model (Callable[..., Any]) – NumPyro model function (callable with no required args).
config (CMCConfig) – CMC configuration with sampling hyperparameters.
rng_key (Array) – JAX PRNG key for reproducibility.
init_params (dict[str, Array] | None) – Optional initial parameter values per chain. Keys are parameter names; values have shape (num_chains,).
- Return type:
- Returns:
Dictionary mapping parameter names to sample arrays. Each array has shape (num_samples * num_chains,) for ungrouped samples, matching NumPyro’s default get_samples() behavior.
- __init__(*args, **kwargs)
- class heterodyne.optimization.cmc.backends.base.CMCBackend[source]
Bases: ABC
Abstract base class for CMC execution backends.
Concrete subclasses implement CPU MCMC execution strategies (sequential, multi-device parallel, worker-pool, etc.). Subclasses must override run, get_capabilities, validate_resources, estimate_memory, and cleanup.
- abstractmethod run(model, config, rng_key, init_params=None)[source]
Run MCMC sampling and return posterior samples.
- Parameters:
- Return type:
- Returns:
Dictionary mapping parameter names to flat sample arrays.
- abstractmethod get_capabilities()[source]
Return a static description of this backend’s capabilities.
- Return type:
BackendCapabilities
- Returns:
Frozen BackendCapabilities dataclass.
- abstractmethod validate_resources()[source]
Check that required hardware and software resources are available.
- Raises:
RuntimeError – If a required resource (device, library, memory) is unavailable.
- Return type:
- abstractmethod estimate_memory(n_data, n_params, n_chains)[source]
Estimate peak memory consumption for a sampling run.
The estimate is intentionally conservative (upper-bound) to help callers decide whether to proceed or reduce chain count / shard size.
- heterodyne.optimization.cmc.backends.base.select_backend(config)[source]
Select the appropriate MCMC backend.
Selection order (mirrors homodyne select_backend semantics):
1. backend_name == "pbs" → PBSBackend (raises if qsub is not on PATH).
2. backend_name == "multiprocessing" or the legacy alias "jax" → MultiprocessingBackend.
3. backend_name == "pjit" → PjitBackend.
4. backend_name == "cpu" → CPUBackend.
5. backend_name == "auto" (default): heuristic selection. MultiprocessingBackend if n_chains >= 3 and at least 2 physical workers are available; otherwise PjitBackend when len(jax.devices()) > 1; otherwise CPUBackend.
- Parameters:
config (CMCConfig) – CMC configuration.
- Return type:
MCMCBackend
- Returns:
An instantiated backend ready for run().
- Raises:
ValueError – If backend_name is not in the supported set.
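The selection order can be sketched as a pure function that returns a backend class name instead of an instance, which keeps it testable without JAX. The arguments n_workers, n_devices and qsub_available are hypothetical stand-ins for environment probes such as shutil.which("qsub") and len(jax.devices()). Raising RuntimeError for a missing qsub is also an assumption, since the exception type is not documented.

```python
_EXPLICIT = {
    "pbs": "PBSBackend",
    "multiprocessing": "MultiprocessingBackend",
    "jax": "MultiprocessingBackend",  # legacy alias
    "pjit": "PjitBackend",
    "cpu": "CPUBackend",
}


def choose_backend(
    backend_name: str = "auto",
    n_chains: int = 4,
    n_workers: int = 1,
    n_devices: int = 1,
    qsub_available: bool = False,
) -> str:
    """Return the backend class name following the documented selection order."""
    if backend_name == "pbs" and not qsub_available:
        raise RuntimeError("qsub is not on PATH")  # assumed exception type
    if backend_name in _EXPLICIT:
        return _EXPLICIT[backend_name]
    if backend_name != "auto":
        raise ValueError(f"unsupported backend_name: {backend_name!r}")
    # "auto" heuristic: worker pool first, then multi-device pjit, then sequential CPU.
    if n_chains >= 3 and n_workers >= 2:
        return "MultiprocessingBackend"
    if n_devices > 1:
        return "PjitBackend"
    return "CPUBackend"
```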
- class heterodyne.optimization.cmc.backends.base.ShardPosterior[source]
Bases: object
Posterior summary from a single MCMC shard.
- mean
Parameter means, shape (n_params,).
- covariance
Covariance matrix, shape (n_params, n_params).
- n_samples
Number of effective posterior samples in this shard.
- shard_id
Optional identifier for logging / diagnostics.
- mean: ndarray
- covariance: ndarray
- n_samples: int = 0
- shard_id: int = 0
- __init__(mean, covariance, n_samples=0, shard_id=0)
- heterodyne.optimization.cmc.backends.base.consensus_mc(shard_posteriors)[source]
Combine shard posteriors using Consensus Monte Carlo.
Each shard’s posterior is summarised by its mean and covariance. The combined posterior is the precision-weighted average:
Λ_combined = Σ_k Λ_k  (sum of precisions, where Λ_k is the inverse covariance of shard k)
μ_combined = Λ_combined⁻¹ Σ_k Λ_k μ_k
This is exact when the sub-posteriors are Gaussian and the prior factorises across shards (the “embarrassingly parallel” regime).
- Parameters:
shard_posteriors (list[ShardPosterior]) – List of ShardPosterior, one per shard. All must have the same dimensionality.
- Return type:
- Returns:
Tuple of (combined_mean, combined_covariance), where combined_mean has shape (n_params,) and combined_covariance has shape (n_params, n_params).
- Raises:
ValueError – If no shards are provided or shapes are inconsistent.
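When every shard covariance is diagonal, the precision-weighted formula reduces to an independent computation per parameter. The sketch below makes that diagonal assumption and works on plain lists. The documented function uses full covariance matrices, and the name consensus_diagonal is hypothetical.

```python
from typing import List, Sequence, Tuple


def consensus_diagonal(
    means: Sequence[Sequence[float]], variances: Sequence[Sequence[float]]
) -> Tuple[List[float], List[float]]:
    """Precision-weighted combination of shard means, one parameter at a time."""
    if not means or len(means) != len(variances):
        raise ValueError("need at least one shard and matching means/variances")
    n_params = len(means[0])
    combined_mean, combined_var = [], []
    for p in range(n_params):
        precisions = [1.0 / v[p] for v in variances]  # Λ_k for this parameter
        total = sum(precisions)  # Λ_combined = Σ_k Λ_k
        weighted = sum(lam * m[p] for lam, m in zip(precisions, means))
        combined_mean.append(weighted / total)  # Λ_combined⁻¹ Σ_k Λ_k μ_k
        combined_var.append(1.0 / total)
    return combined_mean, combined_var
```

With two equally confident shards at 1.0 and 3.0, the combined mean is 2.0 and the variance halves. A shard with three times the variance gets one third of the weight.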
- heterodyne.optimization.cmc.backends.base.robust_consensus_mc(shard_posteriors, *, outlier_sigma=3.0)[source]
Combine shard posteriors with outlier-resistant weighting.
Like consensus_mc(), but first identifies and downweights outlier shards whose means deviate from the cross-shard median by more than outlier_sigma standard deviations.
Outlier detection uses the median absolute deviation (MAD) of per-shard means for each parameter. Shards flagged as outliers on any parameter have their precision scaled by 1 / n_shards (i.e. they contribute but don’t dominate).
- Parameters:
- Return type:
- Returns:
Tuple of (combined_mean, combined_covariance).
- Raises:
ValueError – If fewer than 2 shards are provided (need ≥ 2 for robust statistics) or shapes are inconsistent.
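A diagonal-covariance sketch of the outlier step follows. A shard is flagged when its mean lies more than outlier_sigma robust standard deviations from the cross-shard median on any parameter. Its precision is then scaled by 1 / n_shards before the precision-weighted combination. Two choices here are assumptions, since the documentation states neither: converting the MAD to a standard deviation with the usual 1.4826 factor, and flagging nothing when the MAD is zero.

```python
import statistics
from typing import List, Sequence, Tuple

MAD_TO_SIGMA = 1.4826  # assumed Gaussian-consistency factor for the MAD


def robust_consensus_diagonal(
    means: Sequence[Sequence[float]],
    variances: Sequence[Sequence[float]],
    outlier_sigma: float = 3.0,
) -> Tuple[List[float], List[float]]:
    """Outlier-aware precision weighting with diagonal shard covariances (sketch)."""
    n_shards = len(means)
    if n_shards < 2:
        raise ValueError("robust combination needs at least 2 shards")
    n_params = len(means[0])
    outlier = [False] * n_shards
    for p in range(n_params):
        col = [m[p] for m in means]
        med = statistics.median(col)
        robust_sd = MAD_TO_SIGMA * statistics.median(abs(x - med) for x in col)
        for k, x in enumerate(col):
            # A zero MAD (most shards identical) flags nothing in this sketch.
            if robust_sd > 0 and abs(x - med) > outlier_sigma * robust_sd:
                outlier[k] = True  # flagged on any parameter -> downweighted on all
    combined_mean, combined_var = [], []
    for p in range(n_params):
        lams = [
            (1.0 / variances[k][p]) * (1.0 / n_shards if outlier[k] else 1.0)
            for k in range(n_shards)
        ]
        total = sum(lams)
        combined_mean.append(sum(lam * means[k][p] for k, lam in enumerate(lams)) / total)
        combined_var.append(1.0 / total)
    return combined_mean, combined_var
```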
- heterodyne.optimization.cmc.backends.base.combine_shard_samples(shard_samples, *, method='consensus_mc', chunk_size=500, seed=42)[source]
Combine raw posterior samples from multiple CMC shards.
Homodyne CMC parity wrapper. Each shard contributes a dictionary of per-parameter posterior draws (typically shape (n_chains, n_samples)); the function returns a single combined dictionary with the same per-parameter shape.
- Pathways (K = number of shards):
Single shard: returned unchanged.
K ≤ chunk_size: single-pass precision-weighted combination on per-shard (mean, variance) summaries, with non-finite filtering and degenerate-shard exclusion (variance < 1e-6 × median variance).
K > chunk_size: moment accumulation across chunks, with a single Gaussian draw at the end. This avoids the recursive precision-inflation bug that arose from re-combining synthetic intermediate samples.
- Parameters:
shard_samples (list[dict[str, ndarray]]) – List of per-shard sample dicts. All must share the same parameter-name set and per-parameter shape (C, S).
method (str) – Combination method. "consensus_mc" / "robust_consensus_mc" / "weighted_gaussian" / "auto" all map to precision-weighted Gaussian recombination at this granularity. "simple_average" averages without precision weighting.
chunk_size (int) – Threshold for hierarchical mode. The default of 500 keeps per-step peak memory bounded.
seed (int) – PRNG seed for the synthetic Gaussian draw at the end.
- Return type:
- Returns:
Combined samples dict with the same keys/shapes as the per-shard inputs.
- Raises:
ValueError – If shard_samples is empty or shards disagree on parameter names.
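The single-pass pathway (K ≤ chunk_size) can be sketched for one parameter:

1. Summarise each shard by its mean and variance.
2. Drop shards with non-finite draws, and shards whose variance is below 1e-6 × the median variance.
3. Combine the rest by precision weighting.
4. Draw synthetic samples from the resulting Gaussian.

This sketch simplifies the module's version in three ways. It treats each shard's (C, S) draws as one flat list. It applies the non-finite filter at the shard level. It returns a caller-chosen number of draws.

```python
import math
import random
import statistics
from typing import List, Sequence


def combine_parameter(
    shard_draws: Sequence[Sequence[float]], n_out: int, seed: int = 42
) -> List[float]:
    """Precision-weighted Gaussian recombination for one parameter (sketch)."""
    summaries = []
    for draws in shard_draws:
        # Non-finite filtering: skip shards with any NaN/inf draw or too few draws.
        if len(draws) < 2 or not all(math.isfinite(x) for x in draws):
            continue
        var = statistics.variance(draws)
        if var > 0:
            summaries.append((statistics.fmean(draws), var))
    if not summaries:
        raise ValueError("no usable shards")
    median_var = statistics.median(v for _, v in summaries)
    # Degenerate-shard exclusion: a near-zero variance would dominate the precision sum.
    kept = [(m, v) for m, v in summaries if v >= 1e-6 * median_var]
    total_precision = sum(1.0 / v for _, v in kept)
    mean = sum(m / v for m, v in kept) / total_precision
    sd = math.sqrt(1.0 / total_precision)
    rng = random.Random(seed)  # synthetic Gaussian draws from the combined posterior
    return [rng.gauss(mean, sd) for _ in range(n_out)]
```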
- heterodyne.optimization.cmc.backends.base.combine_shard_samples_bimodal(shard_samples, *, cluster_param=None, method='consensus_mc', seed=42)[source]
Mode-aware combination — cluster shards by posterior mode then combine within cluster.
Homodyne CMC parity helper for multimodal posteriors. Uses a simple 2-means clustering on per-shard posterior means of cluster_param (default: the first parameter) to partition shards into two modes, then runs combine_shard_samples() within each cluster.
- Parameters:
shard_samples (list[dict[str, ndarray]]) – Per-shard sample dictionaries.
cluster_param (str | None) – Parameter name to use for clustering. None picks the first parameter alphabetically.
method (str) – Combination method passed to combine_shard_samples().
seed (int) – PRNG seed for the clustering tiebreaker and final draws.
- Return type:
- Returns:
Mapping cluster_id -> combined_samples_dict. Cluster ids are "mode_low" and "mode_high", ordered by the mean value of cluster_param. If clustering fails (e.g. fewer than 2 shards per mode), all shards are combined into a single "mode_low" bucket.
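A 1-D 2-means split on per-shard means can be sketched as follows. Initialising the two centroids at the smallest and largest mean is an assumption. The "mode_low" / "mode_high" labels and the single-bucket fallback follow the documented behaviour. The function name is hypothetical; combining within each cluster would then reuse combine_shard_samples().

```python
from typing import Dict, List, Sequence


def two_means_split(shard_means: Sequence[float], n_iter: int = 50) -> Dict[str, List[int]]:
    """Cluster shard indices into low/high modes by their posterior means."""
    lo, hi = min(shard_means), max(shard_means)
    clusters: Dict[str, List[int]] = {"mode_low": [], "mode_high": []}
    for _ in range(n_iter):
        clusters = {"mode_low": [], "mode_high": []}
        for k, m in enumerate(shard_means):
            key = "mode_low" if abs(m - lo) <= abs(m - hi) else "mode_high"
            clusters[key].append(k)
        if not clusters["mode_low"] or not clusters["mode_high"]:
            break
        new_lo = sum(shard_means[k] for k in clusters["mode_low"]) / len(clusters["mode_low"])
        new_hi = sum(shard_means[k] for k in clusters["mode_high"]) / len(clusters["mode_high"])
        if (new_lo, new_hi) == (lo, hi):  # centroids converged
            break
        lo, hi = new_lo, new_hi
    # Documented fallback: too few shards per mode -> single "mode_low" bucket.
    if len(clusters["mode_low"]) < 2 or len(clusters["mode_high"]) < 2:
        return {"mode_low": list(range(len(shard_means)))}
    return clusters
```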
CPU Backend (Sequential / Vectorized)
CPU-optimized MCMC execution backend.
Runs NUTS chains sequentially (one at a time) to avoid memory pressure on CPU-only machines where all chains share the same memory pool.
- class heterodyne.optimization.cmc.backends.cpu_backend.CPUBackend[source]
Bases: CMCBackend
CPU-optimized MCMC backend using sequential chain execution.
Runs each MCMC chain one at a time via NumPyro’s chain_method="sequential" to keep peak memory usage proportional to a single chain. This is the recommended backend for single-device CPU machines.
- run(model, config, rng_key, init_params=None)[source]
Run NUTS sampling with sequential chain execution.
- Parameters:
- Return type:
- Returns:
Dictionary of posterior samples from all chains.
- Raises:
RuntimeError – If MCMC sampling fails.
- get_capabilities()[source]
Return CPU backend capabilities.
The CPU backend runs chains sequentially (one at a time) and does not support cross-device sharding.
- Return type:
BackendCapabilities
- Returns:
BackendCapabilities reflecting sequential CPU execution.
- validate_resources()[source]
Verify that CPU resources are available for sampling.
Checks that at least one JAX CPU device is accessible.
- Raises:
RuntimeError – If no CPU device is found via jax.devices().
- Return type:
- estimate_memory(n_data, n_params, n_chains)[source]
Estimate peak CPU memory for a single sequential chain.
Because chains run sequentially, only one chain’s state is live at any moment, so n_chains does not multiply peak usage.
The formula accounts for:
Flat parameter storage per draw: n_params float64 scalars.
Gradient / momentum buffers: same size as parameters.
Sample storage for the completed chain: num_samples * n_params.
Data residual buffer: n_data float64 scalars.
A conservative overhead multiplier (_CPU_MEMORY_OVERHEAD_FACTOR) for JAX tracing buffers and NumPyro auxiliary state.
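Putting the listed terms together gives an estimate of the following shape. Several parts of this sketch are assumptions:

- The overhead value 3.0 is a hypothetical stand-in for _CPU_MEMORY_OVERHEAD_FACTOR, whose actual value is not given here.
- Gradient and momentum are counted as two separate parameter-sized buffers.
- num_samples is passed explicitly, although the documented signature does not list it.

```python
BYTES_PER_FLOAT64 = 8
CPU_MEMORY_OVERHEAD_FACTOR = 3.0  # hypothetical stand-in for _CPU_MEMORY_OVERHEAD_FACTOR


def estimate_cpu_memory(n_data: int, n_params: int, n_chains: int, num_samples: int) -> int:
    """Peak bytes for one live chain; n_chains is accepted but does not scale the result."""
    del n_chains  # chains run sequentially, so only one chain is live at a time
    params = n_params                  # flat parameter storage per draw
    grad_momentum = 2 * n_params       # gradient + momentum buffers
    samples = num_samples * n_params   # stored draws for the completed chain
    residuals = n_data                 # data residual buffer
    n_floats = params + grad_momentum + samples + residuals
    return int(n_floats * BYTES_PER_FLOAT64 * CPU_MEMORY_OVERHEAD_FACTOR)
```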
Multiprocessing Backend
Parallel shard evaluation across CPU cores using Python
multiprocessing. Each worker process runs an independent NUTS
chain on its own shard, with full JAX persistent compilation cache
support.
Multiprocessing backend for CMC sharded MCMC execution.
This module provides parallel NUTS execution using Python’s multiprocessing module for CPU-based parallelism across CMC shards. Each shard runs as a separate spawned process with its own JAX initialization, avoiding JAX shared-state issues across forked processes.
Key design decisions:
mp_context="spawn" (not fork): JAX cannot be safely shared across fork. Spawned workers re-initialize JAX from scratch.
All NumPyro imports inside worker functions: spawn safety requires that no JAX/NumPyro state exists at import time in the child process.
Shared memory for common data: SharedDataManager places config, parameter-space state, and per-shard arrays in shared memory once, avoiding redundant serialization overhead through spawn.
LPT scheduling: shards are dispatched highest-cost-first to minimize tail latency on identical parallel workers.
Heartbeat thread inside each worker: emits liveness pings so the parent can detect frozen processes and apply heartbeat_timeout.
Adaptive polling: the poll interval grows when no shard has completed recently, reducing CPU overhead during long-running shards.
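LPT (longest processing time first) scheduling can be sketched with a min-heap of worker finish times. Shards are sorted by estimated cost in descending order, and each one goes to the worker that frees up first. The per-shard cost estimate (for example, point count) is an assumption. The sketch only computes the assignment and the resulting makespan, not the actual process dispatch.

```python
import heapq
from typing import Dict, List, Sequence, Tuple


def lpt_schedule(costs: Sequence[float], n_workers: int) -> Tuple[Dict[int, List[int]], float]:
    """Assign shard indices to workers highest-cost-first; return (plan, makespan)."""
    heap = [(0.0, w) for w in range(n_workers)]  # (time the worker becomes free, worker id)
    plan: Dict[int, List[int]] = {w: [] for w in range(n_workers)}
    for shard in sorted(range(len(costs)), key=lambda i: costs[i], reverse=True):
        free_at, worker = heapq.heappop(heap)  # earliest-available worker
        plan[worker].append(shard)
        heapq.heappush(heap, (free_at + costs[shard], worker))
    makespan = max(t for t, _ in heap)
    return plan, makespan
```

Starting the most expensive shards first leaves only short shards for the tail, which is what limits idle time on the remaining workers.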
Optimizations carried over from homodyne v2.22.2:
Batch PRNG key generation: pre-generate all shard keys in one JAX call.
Per-shard shared memory (packed format): 4 segments total regardless of shard count, avoiding fd exhaustion.
deque for pending shards: O(1) popleft instead of O(n) list.pop(0).
Persistent compilation cache via jax.config.update (the env var alone is insufficient in JAX 0.8+; min_compile_time lowered to 0).
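This backend is selected when config.backend_name == "multiprocessing" (or the legacy alias "jax"), or, with backend_name == "auto", when config.num_chains >= 3 and at least 2 physical workers are available.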
- class heterodyne.optimization.cmc.backends.multiprocessing.ArraySpec[source]
Bases: object
Schema for one packed shared-memory array key.
- expected_dtype
Documented dtype. Advisory only — the packer uses the actual array’s dtype. Logged in shape-mismatch errors.
- allow_none
Whether a caller may pass None for this key, in which case it is stored as a zero-length sentinel.
- description
Human-readable label, surfaced in error messages so shape/size mismatches identify the offending physical array (e.g. “two-time correlation matrix”, not just “c2_data”).
- expected_dtype: str
- allow_none: bool
- description: str
- __init__(expected_dtype, allow_none, description)
- class heterodyne.optimization.cmc.backends.multiprocessing.SharedDataManager[source]
Bases: object
Manages shared memory blocks for data common to all CMC shards.
Uses multiprocessing.shared_memory to share config dicts, parameter-space state, initial values, and per-shard arrays across spawned worker processes, avoiding redundant serialisation per shard.
Serialization note: uses pickle internally for trusted internal dicts only (CMCConfig.to_dict(), parameter-space dict). This matches the existing multiprocessing behaviour, which also serialises all process arguments. External/untrusted data is never serialised here.
Must be used as a context manager, or cleanup() must be called in a finally block, to avoid leaking shared memory segments on Linux.
- _shared_blocks
All allocated SharedMemory segments.
- _refs
Named references returned to callers.
- __init__()[source]
- create_shared_bytes(name, data)[source]
Store raw bytes in a shared memory segment.
- create_shared_array(name, array)[source]
Store a numpy array in a shared memory segment.
- create_shared_dict(name, d)[source]
Serialise a trusted internal dict into shared memory.
Only used for CMCConfig.to_dict() and parameter-space dicts. External/untrusted data is never passed here.
- create_shared_shard_arrays(shard_data_list)[source]
Place per-shard numpy arrays into packed shared memory.
Instead of creating one SharedMemory segment per array per shard (n_shards * 4 segments, i.e. many file descriptors), this concatenates all shard arrays for each key into a single shared memory block. Only len(_SHARD_ARRAY_KEYS) segments are created regardless of shard count.
- Parameters:
shard_data_list (list[dict[str, Any]]) – List of shard data dicts, each containing numpy arrays keyed by _SHARD_ARRAY_KEYS plus a scalar noise_scale. Arrays for sigma and weights may be None, in which case a zero-length sentinel is stored.
- Return type:
- Returns:
List of lightweight shard references (shm names + offsets). Each ref dict is small enough to serialise cheaply through spawn.
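The packed layout can be sketched with the standard library's multiprocessing.shared_memory and NumPy: one segment per key, plus a small ref per shard recording where its slice starts. This is a sketch under assumptions: pack_key and load_ref are hypothetical names, the ref layout (shm_name, offset, shape) is illustrative, and everything is cast to float64 for simplicity, whereas the documented packer keeps each array's own dtype.

```python
from multiprocessing import shared_memory

import numpy as np


def pack_key(arrays):
    """Pack one key's per-shard arrays into a single SharedMemory segment.

    Returns the segment and one small ref dict per shard. None entries are
    stored as zero-length sentinels. Arrays are cast to float64 (assumption).
    """
    shapes, flats = [], []
    for a in arrays:
        arr = np.zeros(0) if a is None else np.asarray(a, dtype=np.float64)
        shapes.append(arr.shape)
        flats.append(arr.ravel())
    packed = np.concatenate(flats) if flats else np.zeros(0)
    # SharedMemory refuses size=0, so always request at least one byte.
    shm = shared_memory.SharedMemory(create=True, size=max(packed.nbytes, 1))
    np.ndarray(packed.shape, dtype=np.float64, buffer=shm.buf)[:] = packed
    refs, offset = [], 0
    for shape in shapes:
        refs.append({"shm_name": shm.name, "offset": offset, "shape": shape})
        offset += int(np.prod(shape))
    return shm, refs


def load_ref(ref):
    """Rebuild one shard's array from its ref, as a spawned worker would."""
    size = int(np.prod(ref["shape"]))
    if size == 0:
        return np.zeros(ref["shape"])  # zero-length sentinel, nothing to read
    shm = shared_memory.SharedMemory(name=ref["shm_name"])
    view = np.ndarray((size,), dtype=np.float64, buffer=shm.buf,
                      offset=ref["offset"] * np.dtype(np.float64).itemsize)
    arr = view.reshape(ref["shape"]).copy()
    del view  # release the buffer export before closing the segment
    shm.close()
    return arr
```

Calling pack_key once per key keeps the segment count fixed no matter how many shards there are; the parent still has to close() and unlink() every segment, which is the job the cleanup path exists for.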
- heterodyne.optimization.cmc.backends.multiprocessing.run_joint_pooled_shards_parallel(payloads, *, n_workers, num_chains, progress_bar=True, per_shard_timeout=7200)[source]
Run pooled-model shard payloads across a spawn process pool.
Reuses the proven _init_worker_jax initializer (float64, compilation cache, OpenMP thread pinning). Each element of payloads is the kwargs dict for heterodyne.optimization.cmc.core._joint_pooled_nuts_run(). Returns the per-shard CMCResult objects in input order.
Progress is reported two ways so long runs never block silently: a tqdm bar (interactive terminals) AND periodic logger.info heartbeats (visible in log files, where tqdm is not). The heartbeat fires every heartbeat seconds while waiting and reports elapsed time, completed/total shards, and time since the last completion.
per_shard_timeout bounds the wait: if no shard completes within per_shard_timeout seconds (every in-flight worker has exceeded the budget), the pool is terminated and the remaining shards are returned as failed placeholders. This prevents the het_491ee368 failure mode, where a divergence-storming shard ran 21 000 s (nearly 3x the 7200 s budget) without being killed. Terminated shards become failed sub-posteriors that Consensus MC drops via its convergence_passed gate: the run degrades, it does not hang.
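The stall rule above (stop once nothing has completed for per_shard_timeout seconds, and return the stragglers as failed placeholders) can be sketched with concurrent.futures. This is an illustration only: drain_with_stall_timeout and the placeholder dict shape are hypothetical, and in the test a thread pool stands in for the spawn-based process pool. Unlike processes, threads cannot be killed, so a timed-out task keeps running in the background until it returns.

```python
import concurrent.futures as cf
import logging
import time

logger = logging.getLogger("cmc_sketch")


def drain_with_stall_timeout(executor, tasks, per_shard_timeout, heartbeat=30.0):
    """Run tasks and return results in input order, giving up once no task has
    completed for per_shard_timeout seconds. Stragglers and crashed tasks come
    back as failed placeholders instead of hanging the caller."""
    index = {executor.submit(task): i for i, task in enumerate(tasks)}
    results = [None] * len(tasks)
    pending = set(index)
    start = last_completion = time.monotonic()
    while pending:
        # Wake on the first completion, or after `heartbeat` seconds of silence.
        done, pending = cf.wait(pending, timeout=heartbeat,
                                return_when=cf.FIRST_COMPLETED)
        now = time.monotonic()
        for fut in done:
            last_completion = now
            try:
                results[index[fut]] = fut.result()
            except Exception as exc:
                results[index[fut]] = {"failed": True, "error": repr(exc)}
        if pending and now - last_completion >= per_shard_timeout:
            for fut in pending:
                fut.cancel()  # the real backend terminates the whole pool here
                results[index[fut]] = {"failed": True, "error": "timeout"}
            break
        # One heartbeat line per wake-up.
        logger.info("elapsed %.0fs, %d/%d done, %.0fs since last completion",
                    now - start, len(tasks) - len(pending), len(tasks),
                    now - last_completion)
    return results
```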
- class heterodyne.optimization.cmc.backends.multiprocessing.LPTScheduler[source]
Bases: object
Longest Processing Time scheduler for load balancing across cores.
Assigns shards to workers based on estimated computation time using the LPT (Longest Processing Time first) heuristic. The highest-cost shards are dispatched first, so the shards still running at the end are the cheapest ones, which minimises overall tail latency.
This is a simple greedy scheduler; it does not account for real-time feedback about actual execution durations.
- n_workers
Number of parallel workers.
- _shard_order
Deque of shard indices sorted by descending cost.
- __init__(shard_costs, n_workers)[source]
Initialise the LPT scheduler.
- classmethod from_shard_data(shard_data_list, n_workers, n_params=_N_PARAMS_HETERODYNE, n_samples=1000)[source]
Build an LPTScheduler from raw shard data dicts.
Cost is estimated via _estimate_shard_time() for each shard.
- Parameters:
- Return type:
LPTScheduler
- Returns:
Configured LPTScheduler.
- next_shard()[source]
Pop and return the next shard index to dispatch.
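The LPT heuristic fits in a few lines, and a greedy simulation shows why dispatch order matters. lpt_order and makespan are hypothetical helpers for illustration; costs are passed in directly rather than estimated with _estimate_shard_time().

```python
import heapq
from collections import deque


def lpt_order(shard_costs):
    """Deque of shard indices, most expensive first (ties keep index order)."""
    return deque(sorted(range(len(shard_costs)), key=lambda i: -shard_costs[i]))


def makespan(order, shard_costs, n_workers):
    """Greedy dispatch: each next shard goes to whichever worker frees up first."""
    pending = deque(order)
    finish_times = [0.0] * n_workers  # min-heap of per-worker finish times
    while pending:
        i = pending.popleft()  # O(1) pop, as with a deque of pending shards
        t = heapq.heappop(finish_times)
        heapq.heappush(finish_times, t + shard_costs[i])
    return max(finish_times)
```

With costs [1, 1, 1, 1, 4] on two workers, index-order dispatch finishes at time 6 because the expensive shard starts last, while LPT order finishes at 4.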
- class heterodyne.optimization.cmc.backends.multiprocessing.MultiprocessingBackend[source]
Bases: CMCBackend
CMC backend that parallelises NUTS across shards via spawned processes.
Each shard runs as an independent Python process so that JAX is initialised fresh per shard, avoiding the shared-state issues that arise when forking a process that already has a JAX runtime loaded.
Shared data (config, parameter space, initial values, per-shard arrays) is placed in SharedDataManager once in the parent and accessed via _load_shared_* in each child, minimising serialisation overhead through spawn.
The run() method provides the standard single-shard CMCBackend contract (sequential chain execution, no subprocess overhead). For multi-shard CMC, use run_shards(), which orchestrates the full parallel dispatch loop.
- n_workers
Number of concurrent worker processes.
- spawn_method
Multiprocessing start method (always "spawn").
- _shared_mgr
Active SharedDataManager during run_shards(); None otherwise.
- __init__(n_workers=None, spawn_method='spawn')[source]
Initialise the multiprocessing backend.
- Parameters:
- Raises:
ValueError – If spawn_method="fork" is requested.
- n_workers: int
- spawn_method: str
- run(model, config, rng_key, init_params=None)[source]
Run NUTS sampling for a single shard (standard CMCBackend contract).
For multi-shard CMC, call run_shards() instead. This method provides API parity with CPUBackend and PjitBackend using sequential chain execution and no subprocess overhead.
- Parameters:
- Return type:
- Returns:
Dictionary of posterior samples from all chains.
- Raises:
RuntimeError – If MCMC sampling fails.
- get_capabilities()[source]
Return multiprocessing backend capabilities.
- Return type:
BackendCapabilities
- Returns:
BackendCapabilities indicating sharding support, with parallel shards equal to n_workers.
- validate_resources()[source]
Check that CPU resources and multiprocessing are available.
- Raises:
RuntimeError – If no JAX CPU device is found or if the multiprocessing module cannot create a spawn context.
- Return type:
- estimate_memory(n_data, n_params, n_chains)[source]
Estimate peak memory for all concurrent workers combined.
Conservative upper bound: each worker holds one chain’s live state (params + momentum + gradients) plus the data buffer. Workers run chains sequentially, so n_chains does not multiply within a single worker.
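The accounting described for estimate_memory can be written out as arithmetic. This is a sketch under assumptions: the helper name, passing n_workers explicitly (the method reads it from the backend), decimal gigabytes, and the absence of any extra safety margin are all illustrative choices, not the method's documented internals.

```python
def estimate_peak_memory_gb(n_data: int, n_params: int, n_workers: int,
                            bytes_per_value: int = 8) -> float:
    """Hypothetical upper bound for all concurrent workers combined.

    Per worker: one chain's params + momentum + gradients (3 * n_params values)
    plus the shard's data buffer (n_data values), stored as float64.
    n_chains is deliberately absent: chains run sequentially inside a worker.
    """
    per_worker_bytes = (3 * n_params + n_data) * bytes_per_value
    return n_workers * per_worker_bytes / 1e9  # decimal GB
```

For one million data points, 14 parameters and 4 workers this gives about 0.032 GB; doubling the worker count doubles the estimate, while adding chains does not change it.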
- cleanup()[source]
Release shared memory and any other resources.
Idempotent — safe to call multiple times.
- Return type:
- run_shards(shards, config, initial_values=None, parameter_space=None, prior_width_multiplier=1.0, nlsq_uncertainties=None, nlsq_prior_width_factor=2.0, progress_bar=True)[source]
Run NUTS in parallel across all CMC shards.
Orchestrates the full parallel dispatch loop:
1. Allocate shared memory for config, parameter space, initial values, and per-shard arrays.
2. Pre-generate all PRNG keys in the parent process.
3. Dispatch shards to worker processes in LPT order.
4. Drain the result queue with adaptive polling.
5. Enforce per-shard and heartbeat timeouts.
6. Validate and return successful shard results.
- Parameters:
shards (list[dict[str, Any]]) – List of shard dicts. Each must contain at minimum c2_data (numpy array). Optional keys: sigma, t, weights, noise_scale, q, dt, phi_angle, contrast, offset, reparam_config_dict.
config (CMCConfig) – CMC configuration with NUTS hyperparameters and timeout settings.
initial_values (dict[str, Any] | None) – Optional NLSQ warm-start values shared across all shards.
parameter_space (Any | None) – Optional ParameterSpace instance. Its internal config dict is serialised into shared memory.
progress_bar (bool) – Whether to show a tqdm progress bar.
- Return type:
- Returns:
List of validated result dicts, one per successful shard. Each dict contains shard_idx, samples, n_chains, n_samples, param_names, extra_fields, duration, and stats.
- Raises:
ValueError – If shards is empty.
RuntimeError – If all shards fail, or if the success rate falls below config.min_success_rate.
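The final validation step and its two RuntimeError conditions can be sketched as a small gate. This is a hypothetical helper, not the backend's code: it assumes each result is a dict carrying shard_idx and a boolean "success" flag (the flag name is an assumption) and that failed shards may also appear as None.

```python
def validate_shard_results(results, n_shards, min_success_rate):
    """Keep successful shard results and enforce the minimum success rate."""
    if n_shards == 0:
        raise ValueError("shards is empty")
    ok = [r for r in results if r is not None and r.get("success", False)]
    if not ok:
        raise RuntimeError(f"all {n_shards} shards failed")
    rate = len(ok) / n_shards
    if rate < min_success_rate:
        raise RuntimeError(
            f"success rate {rate:.0%} is below min_success_rate {min_success_rate:.0%}")
    # Results may arrive in completion order; return them in shard order.
    return sorted(ok, key=lambda r: r["shard_idx"])
```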
PBS/HPC Backend
Submits per-shard NUTS sampling as independent PBS batch jobs for use on HPC cluster nodes (e.g. ALCF Polaris).
PBS/Torque job submission backend for Consensus Monte Carlo.
Submits per-shard MCMC sampling as PBS batch jobs, collects results from completed jobs, and combines them. Designed for HPC clusters running PBS Professional or Torque where each shard runs on a separate node.
Usage:
from heterodyne.optimization.cmc.backends.pbs import PBSBackend, PBSConfig
cfg = PBSConfig(queue="large", walltime="04:00:00", nodes=1, ppn=8)
backend = PBSBackend(pbs_config=cfg)
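# model_fn (NumPyro model), cmc_config (CMCConfig) and rng_key (JAX PRNG key)
# are assumed to be defined by the caller.
samples = backend.run(model_fn, cmc_config, rng_key)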
- class heterodyne.optimization.cmc.backends.pbs.PBSConfig[source]
Bases: object
Configuration for PBS/Torque job submission.
- queue
Target PBS queue name (e.g. "batch").
- walltime
Maximum wall-clock time in HH:MM:SS format.
- nodes
Number of nodes per shard job.
- ppn
Processors per node.
- memory
Memory request per node (e.g. "4gb").
- python_executable
Python interpreter accessible in the PBS job environment (full path or bare name like "python3").
- working_dir
Directory for temporary files. Defaults to tempfile.gettempdir() when None.
- extra_pbs_directives
Raw #PBS lines injected verbatim after the standard resource block.
- poll_interval
Seconds between qstat polls.
- max_retries
Re-submission attempts for failed shards.
- cleanup_on_success
Delete temporary files after successful runs.
- queue: str = 'batch'
- walltime: str = '01:00:00'
- nodes: int = 1
- ppn: int = 1
- memory: str = '4gb'
- python_executable: str = 'python'
- poll_interval: float = 30.0
- max_retries: int = 2
- cleanup_on_success: bool = True
- __post_init__()[source]
Codex W5: validate every shell-bound field at construction time.
Misconfigured PBS jobs fail here (cheap) instead of after a qsub round-trip (expensive) or, worse, after the malformed script runs. Each validator raises ValueError with a message that names the field and the offending characters.
- Return type:
- __init__(queue='batch', walltime='01:00:00', nodes=1, ppn=1, memory='4gb', python_executable='python', working_dir=None, extra_pbs_directives=<factory>, poll_interval=30.0, max_retries=2, cleanup_on_success=True)
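Construction-time validation of shell-bound fields, as described for __post_init__, can be sketched as follows. Everything here is hypothetical: PBSConfigSketch, the regular expressions, and resource_block() are illustrative, since the documented validators' exact rules are not given. The rendered #PBS lines use common Torque resource syntax, which is an assumption about what this backend writes.

```python
import re
from dataclasses import dataclass, field

_SAFE_NAME = re.compile(r"[A-Za-z0-9_.@-]+")
_WALLTIME = re.compile(r"\d{1,3}:[0-5]\d:[0-5]\d")
_MEMORY = re.compile(r"\d+(kb|mb|gb|tb)", re.IGNORECASE)


@dataclass
class PBSConfigSketch:
    """Hypothetical subset of PBSConfig with illustrative validation rules."""
    queue: str = "batch"
    walltime: str = "01:00:00"
    nodes: int = 1
    ppn: int = 1
    memory: str = "4gb"
    extra_pbs_directives: list = field(default_factory=list)

    def __post_init__(self):
        # fullmatch (not match + "$") so a trailing newline cannot sneak through.
        if not _SAFE_NAME.fullmatch(self.queue):
            raise ValueError(f"queue: unsafe characters in {self.queue!r}")
        if not _WALLTIME.fullmatch(self.walltime):
            raise ValueError(f"walltime: expected HH:MM:SS, got {self.walltime!r}")
        if not _MEMORY.fullmatch(self.memory):
            raise ValueError(f"memory: expected e.g. '4gb', got {self.memory!r}")
        if self.nodes < 1 or self.ppn < 1:
            raise ValueError("nodes and ppn must be >= 1")
        for line in self.extra_pbs_directives:
            if not line.startswith("#PBS") or "\n" in line or "\r" in line:
                raise ValueError(f"extra_pbs_directives: not a single #PBS line: {line!r}")

    def resource_block(self) -> str:
        # Torque-style "nodes=N:ppn=M"; PBS Professional sites usually use "-l select=...".
        lines = [
            f"#PBS -q {self.queue}",
            f"#PBS -l walltime={self.walltime}",
            f"#PBS -l nodes={self.nodes}:ppn={self.ppn}",
            f"#PBS -l mem={self.memory}",
        ]
        return "\n".join(lines + list(self.extra_pbs_directives))
```

Rejecting newlines matters most: a value like "01:00:00\n#PBS -q other" would otherwise inject an extra directive into the generated script.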
- class heterodyne.optimization.cmc.backends.pbs.ShardResult[source]
Bases: object
Result from a single shard MCMC job.
- shard_id
Zero-based shard index.
- samples
Posterior samples keyed by parameter name.
- job_id
PBS job identifier string.
- success
True when sampling completed without errors.
- error_message
Populated when success is False.
- shard_id: int
- job_id: str
- success: bool = True
- error_message: str = ''
- __init__(shard_id, samples, job_id, success=True, error_message='')
- class heterodyne.optimization.cmc.backends.pbs.PBSBackend[source]
Bases: CMCBackend
PBS/Torque backend for distributed CMC sampling.
Each data shard is submitted as an independent PBS batch job. The main process polls qstat until all jobs terminate, then reads per-shard .npz result files and concatenates the samples.
- Parameters:
pbs_config (PBSConfig | None) – PBS resource and scheduling options. Defaults to a PBSConfig() with queue="batch" when None.
- __init__(pbs_config=None)[source]
- run(model, config, rng_key, init_params=None)[source]
Submit per-shard PBS jobs, wait for completion, return combined samples.
- Parameters:
- Return type:
- Returns:
Dictionary mapping parameter names to concatenated numpy arrays.
- Raises:
RuntimeError – If PBS is unavailable or any shard fails.
TimeoutError – If config.shard_timeout_seconds is exceeded.
- get_capabilities()[source]
Return PBS backend capabilities.
- Return type:
BackendCapabilities
- validate_resources()[source]
Check that qsub is on PATH.
- Raises:
RuntimeError – If qsub is not found.
- Return type:
- estimate_memory(n_data, n_params, n_chains)[source]
Estimate peak memory per PBS node (in GB) for a single shard.
Because each shard runs on a separate node, the result is per-node only. The estimate is conservative (upper-bound) to prevent OOM kills.
- Return type:
- cleanup(job_ids=None)[source]
Cancel active jobs and delete temporary files.
- submit_shard(shard_data, model_fn, config_dict, shard_id, seed)[source]
Serialize shard payload, write PBS script, and submit via qsub.
- Parameters:
shard_data (dict[str, Any]) – Auxiliary shard-specific data (e.g. init_params).
model_fn (Callable[..., Any]) – Picklable NumPyro model function.
config_dict (dict[str, Any]) – Flat NUTS hyperparameter dict.
shard_id (int) – Zero-based shard index (drives file naming).
seed (int) – Integer random seed for this shard.
- Return type:
- Returns:
PBS job ID string.
- Raises:
RuntimeError – If qsub fails.
- wait_for_jobs(job_ids, timeout=None)[source]
Poll qstat until all jobs reach a terminal state.
Jobs absent from qstat (purged from the accounting database) are treated as complete; the .npz file determines success or failure.
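The completion rule (terminal state, or no longer listed by qstat) can be sketched as a pure function plus a poll loop. Both helpers are hypothetical, qstat output parsing is omitted (poll_fn stands in for it), and treating "C" (Torque "completed") and "F" (PBS Professional "finished") as the terminal states is an assumption. sleep and clock are injectable so the loop can be exercised without waiting.

```python
import time


def terminal_jobs(job_ids, qstat_states, terminal=frozenset({"C", "F"})):
    """Jobs that are done: reported in a terminal state, or no longer reported.

    qstat_states maps job id -> one-letter state for jobs qstat still lists.
    A job missing from the mapping is assumed purged and therefore complete;
    its .npz result file then decides success or failure."""
    return {j for j in job_ids if j not in qstat_states or qstat_states[j] in terminal}


def wait_for_jobs_sketch(job_ids, poll_fn, poll_interval=30.0, timeout=None,
                         sleep=time.sleep, clock=time.monotonic):
    """Poll until every job is terminal; poll_fn returns a fresh state mapping."""
    start = clock()
    while True:
        if terminal_jobs(job_ids, poll_fn()) >= set(job_ids):
            return
        if timeout is not None and clock() - start > timeout:
            raise TimeoutError(f"jobs still running after {timeout}s")
        sleep(poll_interval)
```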
- run_shards(shards, model_fn, config, seeds)[source]
Submit and collect an explicit list of pre-partitioned shards.
Use this when the caller has already split the data and needs to attach per-shard payloads. For the standard pipeline, use run().
- Parameters:
- Return type:
list[ShardResult]
- Returns:
List of ShardResult in shard order.
- Raises:
ValueError – If len(shards) != len(seeds).
RuntimeError – If PBS is unavailable or submission fails.
TimeoutError – If config.shard_timeout_seconds is exceeded.
Pjit Backend (Multi-Device)
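Experimental backend that shards chains across jax.devices() using JAX's sharding API (despite the module name, it uses jax.jit with jax.sharding annotations rather than the deprecated jax.experimental.pjit). Currently CPU-only because heterodyne ships no GPU support, but the interface is device-agnostic.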
Multi-device parallel MCMC backend using JAX sharding.
Distributes NUTS chains across multiple CPU devices using JAX’s modern
sharding API (jax.sharding), available since JAX 0.4.1. This
replaces the deprecated jax.experimental.pjit with the stable
jax.jit + sharding annotation approach. Heterodyne is CPU-only.
- class heterodyne.optimization.cmc.backends.pjit.PjitBackend[source]
Bases: CMCBackend
Multi-device parallel MCMC backend using JAX sharding.
Distributes NUTS chains across all available JAX devices. Each device runs a subset of the requested chains in parallel via NumPyro’s vectorized chain execution.
When only a single device is available, this backend transparently falls back to running all chains on that device (equivalent to the single-device chain_method="parallel").
This backend uses the modern jax.sharding API (stable since JAX 0.4.1), not the deprecated jax.experimental.pjit.
- run(model, config, rng_key, init_params=None)[source]
Run NUTS sampling distributed across multiple devices.
Splits chains across available devices. Each device group runs independently, and results are gathered and concatenated.
- Parameters:
- Return type:
- Returns:
Dictionary mapping parameter names to flat sample arrays.
- Raises:
RuntimeError – If sampling fails on any device.
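One way to split chains across devices, including the single-device fallback, is an even divmod split. This is an assumption for illustration: split_chains is a hypothetical helper, and the backend's actual split policy is not documented here.

```python
def split_chains(n_chains: int, n_devices: int) -> list:
    """Chains per device, spread as evenly as possible.

    With a single device every chain lands on it, matching the documented
    single-device fallback. Devices that would get zero chains are dropped."""
    if n_devices < 1:
        raise RuntimeError("no JAX devices available")
    base, extra = divmod(n_chains, n_devices)
    counts = [base + (1 if d < extra else 0) for d in range(n_devices)]
    return [c for c in counts if c > 0]
```

On a CPU-only host, len(jax.devices()) is 1 unless XLA is told to expose more host devices (for example with --xla_force_host_platform_device_count in XLA_FLAGS), so by default every chain lands on one device.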
- get_capabilities()[source]
Return capabilities for multi-device parallel execution.
- Return type:
BackendCapabilities
- Returns:
BackendCapabilities with sharding support flags.
- validate_resources()[source]
Check that multiple devices are available for sharding.
Logs a warning (but does not raise) when only one device is detected, since the backend can still function in single-device mode.
- Raises:
RuntimeError – If no JAX devices are available at all.
- Return type:
- estimate_memory(n_data, n_params, n_chains)[source]
Estimate peak memory per device in GB.
Each device holds a fraction of the chains. Memory per chain is approximately: n_params x n_data x 8 bytes (float64) for the likelihood evaluation, plus sample storage.
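The per-device estimate above can be written out directly. This is a sketch under assumptions: the helper name, sizing for the busiest device (ceil(n_chains / n_devices) chains), an illustrative n_samples for the "sample storage" term, and decimal gigabytes are all choices made here, not documented internals.

```python
import math


def per_device_memory_gb(n_data, n_params, n_chains, n_devices, n_samples=1000):
    """Hypothetical per-device estimate following the description above."""
    chains_here = math.ceil(n_chains / max(n_devices, 1))  # busiest device
    likelihood_bytes = n_params * n_data * 8  # float64 likelihood working set per chain
    sample_bytes = n_samples * n_params * 8   # stored draws per chain
    return chains_here * (likelihood_bytes + sample_bytes) / 1e9
```

With one million data points, 14 parameters, and 4 chains on 2 devices, each device holds 2 chains and the estimate is about 0.224 GB, dominated by the likelihood term.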
- heterodyne.optimization.cmc.backends.pjit.combine_shard_samples(shard_results)[source]
Combine posterior samples from multiple device shards.
Concatenates sample arrays along the first axis (samples dimension).
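The documented behaviour (concatenation along the samples axis) can be sketched with NumPy. combine_shard_samples_sketch is a hypothetical stand-in, not the library function; the parameter-name check and the empty-input behaviour are assumptions. Plain concatenation is appropriate here because device shards hold chains of the same posterior; it is not a Consensus MC combination of sub-posteriors.

```python
import numpy as np


def combine_shard_samples_sketch(shard_results):
    """Concatenate per-device sample dicts along axis 0 (the samples axis).

    Assumes every device reports the same parameter names."""
    if not shard_results:
        return {}
    names = shard_results[0].keys()
    for i, res in enumerate(shard_results[1:], start=1):
        if res.keys() != names:
            raise ValueError(f"shard {i} has parameters {sorted(res)}, expected {sorted(names)}")
    return {name: np.concatenate([np.asarray(r[name]) for r in shard_results], axis=0)
            for name in names}
```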
Persistent Worker Pool
Long-lived worker pool that amortises per-shard JAX JIT compilation overhead. Recommended for large analyses with many shards.
Worker pool backend for multi-shard CMC execution.
Spawns persistent workers that each run MCMC on assigned shards. Amortizes JAX/NumPyro initialization overhead across tasks.
- class heterodyne.optimization.cmc.backends.worker_pool.WorkerPoolBackend[source]
Bases: object
Persistent worker pool for multi-shard CMC execution.
Distributes MCMC shards across a pool of worker processes. Each worker runs one chain per shard, and results are combined.
- __init__(n_workers=None)[source]
Initialize with optional worker count.
- property n_workers: int
- static should_use_pool(n_shards, n_workers)[source]
Check if pool execution is beneficial.
Homodyne parity: pool is used whenever there are at least 3 shards, regardless of the worker count. The pool’s startup cost is amortised across shards, not workers, so the gate only needs to reject trivially small shard counts.
- run(model, config, rng_key, init_params=None)[source]
Run MCMC via worker pool.
- Parameters:
- Return type:
- Returns:
Combined posterior samples from all workers.
- heterodyne.optimization.cmc.backends.worker_pool.should_use_persistent_pool(n_shards, n_workers)[source]
Homodyne-parity gate: persistent pool helps when there are >=3 shards.
- Return type:
- class heterodyne.optimization.cmc.backends.worker_pool.PersistentWorkerPool[source]
Bases: object
Persistent process pool for CMC shard dispatch (homodyne parity).
Workers are spawned once, perform a one-time initialization (e.g. JAX warm-up / JIT priming) under control of worker_init_fn, signal readiness, and then process tasks from a per-worker queue until a None sentinel is received. The pool exposes a round-robin submit(), a blocking get_result(), and a deterministic shutdown() that drains task queues, joins processes, and terminates any that refuse to exit.
The class supports the context-manager protocol; __exit__ calls shutdown().
- Parameters:
n_workers (int) – Number of persistent worker processes.
worker_fn (Callable[..., dict[str, Any] | None]) – Picklable module-level callable invoked per task. Signature worker_fn(task: dict, **init_kwargs) -> dict | None. Returning None skips putting a result on the shared queue, which is useful when worker_fn already manages its own result emission.
worker_init_kwargs (dict[str, Any] | None) – One-time kwargs forwarded to both worker_init_fn (if provided) and every worker_fn call. Must be picklable.
worker_init_fn (Callable[..., None] | None) – Optional one-time initializer with signature init_fn(worker_id: int, **init_kwargs) -> None. Use it for expensive setup like JAX backend selection or JIT pre-warming.
startup_timeout (float) – Maximum seconds to wait for all workers to signal readiness before continuing with whatever subset is ready.
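The dispatch protocol described for this class (per-worker task queues, a shared result queue, round-robin submit(), a None sentinel, and shutdown on context exit) can be sketched with threads. ThreadWorkerPoolSketch is hypothetical: it uses threading instead of spawned processes and omits worker_init_fn, the readiness handshake, startup_timeout, and forced termination. An exception in worker_fn ends that worker thread in this sketch.

```python
import itertools
import queue
import threading


class ThreadWorkerPoolSketch:
    """Thread-based stand-in for the documented persistent-pool protocol."""

    def __init__(self, n_workers, worker_fn, worker_init_kwargs=None):
        self._fn = worker_fn
        self._kwargs = dict(worker_init_kwargs or {})
        self._tasks = [queue.Queue() for _ in range(n_workers)]  # one queue per worker
        self.result_queue = queue.Queue()  # shared by all workers
        self._next = itertools.cycle(range(n_workers))  # round-robin target
        self._threads = [threading.Thread(target=self._loop, args=(q,), daemon=True)
                         for q in self._tasks]
        for t in self._threads:
            t.start()

    def _loop(self, tasks):
        while (task := tasks.get()) is not None:  # None is the shutdown sentinel
            result = self._fn(task, **self._kwargs)
            if result is not None:  # None means "emit nothing", as documented
                self.result_queue.put(result)

    def submit(self, task):
        self._tasks[next(self._next)].put(task)

    def get_result(self, timeout=300.0):
        return self.result_queue.get(timeout=timeout)  # raises queue.Empty on timeout

    def shutdown(self, join_timeout=5.0):
        for q in self._tasks:
            q.put(None)  # queued tasks ahead of the sentinel still run
        for t in self._threads:
            t.join(timeout=join_timeout)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.shutdown()
```

Because each worker reads its own queue, a task submitted before shutdown() is always processed before that worker sees its sentinel.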
- __init__(n_workers, worker_fn, worker_init_kwargs=None, worker_init_fn=None, startup_timeout=120.0)[source]
- property n_workers: int
- property result_queue: Queue
- get_result(timeout=300.0)[source]
Block until a result is available; raises queue.Empty on timeout.
- results_pending()[source]
True when the shared result queue has at least one entry.
- Return type: