"""Configuration for CMC (Consensus Monte Carlo) analysis.
This module defines CMCConfig, a comprehensive dataclass covering all aspects of
the heterodyne CMC pipeline: sharding strategy, backend selection, NUTS sampling
parameters, convergence validation thresholds, reparameterization, prior tempering,
shard combination, and run identification.
The heterodyne model has 14 free parameters (vs. 7 in homodyne). All auto-scaling
formulas account for this increased dimensionality.
"""
from __future__ import annotations
import math
import warnings
from dataclasses import dataclass, field
from typing import Any, ClassVar
from heterodyne.utils.logging import get_logger
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Module-level constants
# ---------------------------------------------------------------------------
#: Default number of model parameters for heterodyne (14-parameter model).
_N_PARAMS_HETERODYNE: int = 14
#: Minimum ratio of data points to parameters required per shard.
_MIN_POINTS_PER_PARAM_DEFAULT: int = 1_500
#: Reference shard size used for adaptive sample-count scaling (10 K points → full).
_REFERENCE_SHARD_SIZE: int = 10_000
#: Rule 12: NUTS warmup floor when ``dense_mass=True``. Dense mass-matrix
#: adaptation needs ≥ 100 steps per dimension; for the 14-parameter heterodyne
#: model, 1500 warmup steps is the minimum that yields healthy R-hat in regression.
DENSE_MASS_WARMUP_FLOOR: int = 1500
# Process-once warning gate so the fast_warmup advisory is logged at most once.
_FAST_WARMUP_WARNED: bool = False
[docs]
def effective_warmup_floor(
requested: int, *, dense_mass: bool, fast_warmup: bool = False
) -> int:
"""Return the safe warmup count, applying the Rule 12 floor when needed.
When ``dense_mass=True`` the NUTS dense mass-matrix adaptation requires
at least :data:`DENSE_MASS_WARMUP_FLOOR` steps; otherwise the requested
value is returned unchanged. ``fast_warmup=True`` opts out of the floor
and emits a one-shot warning — intended for CI fast-mode only.
Args:
requested: Caller-supplied (possibly scaled) warmup count.
dense_mass: Whether the sampler is configured with ``dense_mass=True``.
fast_warmup: Skip the Rule-12 floor (logs a one-shot warning).
Returns:
Warmup count to use.
"""
global _FAST_WARMUP_WARNED
if fast_warmup:
if not _FAST_WARMUP_WARNED:
logger.warning(
"fast_warmup=True bypasses Rule 12 warmup floor "
"(DENSE_MASS_WARMUP_FLOOR=%d) — for testing only, not production",
DENSE_MASS_WARMUP_FLOOR,
)
_FAST_WARMUP_WARNED = True
return max(1, int(requested))
if dense_mass and int(requested) < DENSE_MASS_WARMUP_FLOOR:
return DENSE_MASS_WARMUP_FLOOR
return max(1, int(requested))
# Valid string literals for each enumerated field.
_VALID_ENABLE: frozenset[str] = frozenset({"auto", "always", "never"})
_VALID_PER_ANGLE_MODE: frozenset[str] = frozenset(
{"auto", "constant", "constant_averaged", "individual"}
)
_VALID_SHARDING_STRATEGY: frozenset[str] = frozenset(
{"stratified", "random", "contiguous"}
)
_VALID_BACKEND_NAME: frozenset[str] = frozenset(
# "jax" is a legacy alias for "multiprocessing"; routed in select_backend.
# "jit" is the new name for "pjit" (device/config.CMCBackend.JIT); both accepted.
# "slurm" has no native backend — select_backend falls back to MP.
{"auto", "multiprocessing", "pjit", "jit", "cpu", "pbs", "slurm", "jax"}
)
_VALID_CHAIN_METHOD: frozenset[str] = frozenset(
{"parallel", "sequential", "vectorized"}
)
_VALID_INIT_STRATEGY: frozenset[str] = frozenset(
{"init_to_median", "init_to_sample", "init_to_value"}
)
_VALID_COMBINATION_METHOD: frozenset[str] = frozenset(
# "auto" → resolved to robust_consensus_mc in _combine_shard_posteriors.
{
"auto",
"consensus_mc",
"robust_consensus_mc",
"weighted_gaussian",
"simple_average",
}
)
# ---------------------------------------------------------------------------
# CMCConfig
# ---------------------------------------------------------------------------
[docs]
@dataclass
class CMCConfig:
"""Comprehensive configuration for Consensus Monte Carlo (CMC) analysis.
CMC splits a large dataset into K shards, runs NUTS independently on each
shard, then combines the resulting posteriors using a consensus algorithm.
This dataclass controls every knob across the full pipeline.
Parameters are grouped into logical sections, matching the structure of the
``to_dict`` / ``from_dict`` serialization format:
- **enable** — master on/off switch and dataset-size gate.
- **per_angle** — how to handle the phi (angle) dimension.
- **sharding** — shard count, strategy, and size bounds.
- **backend_config** — worker backend and checkpoint settings.
- **per_shard_mcmc** — NUTS hyper-parameters and adaptive scaling.
- **validation** — convergence thresholds and abort conditions.
- **nlsq** — NLSQ warm-start and prior-width configuration.
- **prior_tempering** — scale priors by 1/K for shard consistency.
- **combination** — posterior combination algorithm and success criteria.
- **timeout** — per-shard and heartbeat time limits.
- **reparameterization** — parameter transforms and bimodality guards.
- **run_id** — optional identifier for checkpoint namespacing.
Attributes
----------
enable:
Master switch. ``"auto"`` enables CMC when ``n_points >= min_points_for_cmc``.
``True`` / ``"always"`` forces CMC regardless of dataset size.
``False`` / ``"never"`` disables CMC entirely.
min_points_for_cmc:
Minimum number of data points required before CMC is activated under
``enable="auto"``. Below this threshold the pipeline falls back to full
NUTS.
per_angle_mode:
Strategy for handling the angle (phi) dimension.
``"auto"`` selects automatically based on ``n_phi`` and
``constant_scaling_threshold``.
constant_scaling_threshold:
Minimum number of phi angles required before switching from
``"constant"`` to ``"individual"`` mode when ``per_angle_mode="auto"``.
sharding_strategy:
How to partition data across shards: ``"stratified"`` preserves angle
distributions, ``"random"`` shuffles globally, ``"contiguous"`` uses
contiguous memory blocks.
num_shards:
Number of shards ``K``. ``"auto"`` derives ``K`` from dataset size,
phi count, and ``min_points_per_shard`` / ``min_points_per_param``.
max_points_per_shard:
Upper bound on shard size. ``"auto"`` disables the cap.
min_points_per_shard:
Lower bound on shard size; prevents degenerate under-determined shards.
min_points_per_param:
Minimum ratio of points-to-parameters per shard (heterodyne default: 14).
backend_name:
Worker backend. ``"auto"`` selects based on available CPU devices
and core count.
enable_checkpoints:
Persist intermediate shard results to ``checkpoint_dir``.
checkpoint_dir:
Directory for shard checkpoint files.
chain_method:
Whether to run chains in ``"parallel"`` or ``"sequential"`` order within
each shard worker.
num_warmup:
Number of NUTS warm-up (burn-in) steps per chain. Default 1500
provides ~100 steps per parameter for the 14-parameter heterodyne
model when ``dense_mass=True``; lower values leave mass-matrix
adaptation incomplete and produce high R-hat / divergence storms.
num_samples:
Number of posterior draws per chain after warm-up.
num_chains:
Number of independent MCMC chains per shard.
target_accept_prob:
Target acceptance probability for the dual-averaging NUTS step-size
adaptation (must be in ``[0.5, 0.99]``). Default 0.90 keeps step
sizes small enough to traverse the (D0, alpha) funnel without
divergence cascades; reduce only if you understand the geometry.
max_tree_depth:
Maximum binary tree depth for NUTS leapfrog integration.
seed:
Base random seed for deterministic reproducibility.
dense_mass:
Use a dense (full-covariance) mass matrix. Expensive but more
accurate for highly correlated posteriors.
init_strategy:
NUTS initialisation strategy.
adaptive_sampling:
Scale ``num_warmup`` / ``num_samples`` down proportionally when shard
size is below ``_REFERENCE_SHARD_SIZE``.
min_warmup:
Floor on adaptive warm-up count.
min_samples:
Floor on adaptive sample count.
max_r_hat:
Maximum acceptable Gelman-Rubin statistic; chains with ``R-hat >
max_r_hat`` are flagged as not converged.
min_ess:
Minimum effective sample size per parameter.
min_bfmi:
Minimum Bayesian Fraction of Missing Information (energy diagnostic).
max_divergence_rate:
Maximum fraction of divergent transitions before a shard is rejected.
require_nlsq_warmstart:
Abort if an NLSQ warm-start was requested but unavailable.
allow_degenerate_warmstart:
When ``False`` (default) an early ``RuntimeError`` is raised before
dispatching any shards if the warm-start is in a regime that is
guaranteed to cause 100% shard failure (f0 < 0.10 or
alpha_sample < -1.5 — the het_bb97531f failure mode). Set to
``True`` to bypass the abort and let NUTS attempt the run anyway;
useful when you want to see exactly how bad the posteriors are or
when you have increased ``num_warmup`` substantially (≥ 2000).
max_parameter_cv:
Maximum allowed coefficient of variation across chains for any
parameter; guards against pathological multi-modal posteriors.
heterogeneity_abort:
Abort the entire CMC run if shards produce incompatible posteriors
(detected via KL divergence or parameter-CV checks).
use_nlsq_warmstart:
Initialise each shard's NUTS chains from the NLSQ MAP estimate.
use_nlsq_informed_priors:
Centre Gaussian priors on NLSQ estimates scaled by
``nlsq_prior_width_factor``.
nlsq_prior_width_factor:
Scale factor applied to NLSQ parameter uncertainties when constructing
informed priors.
prior_tempering:
Divide log-prior by ``K`` (number of shards) so that the combined
posterior approximates the full-data prior exactly.
combination_method:
Algorithm used to combine shard posteriors.
min_success_rate:
Minimum fraction of shards that must converge; run fails below this.
min_success_rate_warning:
Fraction below which a warning is emitted even if the run succeeds.
per_shard_timeout:
Wall-clock seconds allowed per shard before it is cancelled.
heartbeat_timeout:
Seconds of silence from a worker before it is declared dead.
use_reparam:
Apply parameter reparameterisations (e.g. log-transforms) in NumPyro.
reparameterization_d_total:
Reparameterise ``d_total = d_fast + d_slow`` as an unconstrained sum.
reparameterization_log_gamma:
Reparameterise ``gamma`` on a log scale to enforce positivity.
bimodal_min_weight:
Minimum mixture weight for the minor mode in bimodal posteriors; below
this the minor mode is discarded.
bimodal_min_separation:
Minimum normalised distance between modes to declare bimodality.
run_id:
Optional string identifier for this CMC run, used in checkpoint paths
and log messages.
"""
# ------------------------------------------------------------------
# 1. Enable & dataset-size gate
# ------------------------------------------------------------------
enable: bool | str = "auto"
min_points_for_cmc: int = 100_000
# ------------------------------------------------------------------
# 2. Per-angle mode
# ------------------------------------------------------------------
per_angle_mode: str = "auto"
constant_scaling_threshold: int = 3
# ------------------------------------------------------------------
# 3. Sharding
# ------------------------------------------------------------------
sharding_strategy: str = "random"
num_shards: int | str = "auto"
max_points_per_shard: int | str = "auto"
min_points_per_shard: int = 10_000
min_points_per_param: int = _MIN_POINTS_PER_PARAM_DEFAULT
# ------------------------------------------------------------------
# 4. Backend
# ------------------------------------------------------------------
backend_name: str = "auto"
enable_checkpoints: bool = True
checkpoint_dir: str = "./checkpoints/cmc"
chain_method: str = "parallel"
# ------------------------------------------------------------------
# 5. Sampling
# ------------------------------------------------------------------
num_warmup: int = 1500
num_samples: int = 1500
num_chains: int = 4
target_accept_prob: float = 0.90
max_tree_depth: int = 10
seed: int = 42
dense_mass: bool = True
init_strategy: str = "init_to_median"
adaptive_sampling: bool = True
min_warmup: int = 100
min_samples: int = 200
#: Rule 12 escape hatch: skip the dense-mass warmup floor (1500 steps).
#: Intended for CI fast-mode and pytest fixtures only; NOT for production.
#: When True, all warmup calculations bypass :func:`effective_warmup_floor`.
fast_warmup: bool = False
# ------------------------------------------------------------------
# 6. Validation thresholds
# ------------------------------------------------------------------
max_r_hat: float = 1.1
min_ess: int = 400
min_bfmi: float = 0.3
max_divergence_rate: float = 0.10
require_nlsq_warmstart: bool = False
allow_degenerate_warmstart: bool = False
max_parameter_cv: float = 1.0
heterogeneity_abort: bool = True
# ------------------------------------------------------------------
# 7. NLSQ-informed priors
# ------------------------------------------------------------------
use_nlsq_warmstart: bool = True
use_nlsq_informed_priors: bool = True
nlsq_prior_width_factor: float = 2.0
#: Codex S1: integrate ``build_log_space_priors`` into ``build_default_priors``.
#: When True (default), parameters flagged ``log_space=True`` in the parameter
#: registry (currently D0_ref, D0_sample, v0) are sampled with LogNormal priors
#: instead of TruncatedNormal — better mass-matrix conditioning for prefactors
#: that span several orders of magnitude. When False, the registry's
#: ``log_space`` flag is ignored and all parameters use TruncatedNormal.
#: The reparameterized path (``use_reparam=True``) is unaffected — it samples
#: log_X_at_tref directly and does not reach this code path.
use_log_space_priors: bool = True
# ------------------------------------------------------------------
# 8. Prior tempering
# ------------------------------------------------------------------
prior_tempering: bool = True
# ------------------------------------------------------------------
# 9. Combination
# ------------------------------------------------------------------
combination_method: str = "robust_consensus_mc"
min_success_rate: float = 0.90
min_success_rate_warning: float = 0.80
# ------------------------------------------------------------------
# 10. Timeout
# ------------------------------------------------------------------
per_shard_timeout: int = 7200
heartbeat_timeout: int = 600
# ------------------------------------------------------------------
# 11. JAX profiling
# ------------------------------------------------------------------
enable_jax_profiling: bool = False
jax_profile_dir: str = "./profiles/jax"
# ------------------------------------------------------------------
# 12. Reparameterization
# ------------------------------------------------------------------
use_reparam: bool = True
reparameterization_d_total: bool = True
reparameterization_log_gamma: bool = True
bimodal_min_weight: float = 0.2
bimodal_min_separation: float = 0.5
# ------------------------------------------------------------------
# 13. Run identification
# ------------------------------------------------------------------
run_id: str | None = None
# ------------------------------------------------------------------
# Private: accumulated validation errors (not shown in repr)
# ------------------------------------------------------------------
_validation_errors: list[str] = field(default_factory=list, repr=False)
# ==================================================================
# Backward-compat aliases (deferred task #38 — homodyne parity rename)
# ==================================================================
#
# Three dataclass fields use homodyne-canonical names; legacy aliases
# remain for external YAML configs and downstream code that has not
# yet migrated. Reads of the legacy names route to the canonical
# fields with a single ``DeprecationWarning`` per access.
#
# legacy alias -> canonical field
# target_accept -> target_accept_prob
# r_hat_threshold -> max_r_hat
# prior_width_factor -> nlsq_prior_width_factor
#
# All in-tree call sites use the canonical names directly, so the
# shim adds zero runtime cost on the hot NUTS path. Old names will
# be removed in a future release.
_LEGACY_ALIASES: ClassVar[dict[str, str]] = {
"target_accept": "target_accept_prob",
"r_hat_threshold": "max_r_hat",
"prior_width_factor": "nlsq_prior_width_factor",
}
def __getattr__(self, name: str) -> object:
# __getattr__ only runs when normal lookup fails, so dataclass
# fields are unaffected. We only intercept the 3 legacy aliases.
aliases = type(self).__dict__.get("_LEGACY_ALIASES") or self._LEGACY_ALIASES
if name in aliases:
import warnings
new_name = aliases[name]
warnings.warn(
f"CMCConfig.{name} is deprecated; use CMCConfig.{new_name} instead.",
DeprecationWarning,
stacklevel=2,
)
return object.__getattribute__(self, new_name)
raise AttributeError(f"'CMCConfig' object has no attribute {name!r}")
# ==================================================================
# Post-init
# ==================================================================
[docs]
def __post_init__(self) -> None:
"""Normalise string-valued enable flag and log construction."""
# Coerce boolean True/False to canonical string forms.
if self.enable is True:
self.enable = "always"
elif self.enable is False:
self.enable = "never"
# Coerce removed "gpu" backend to "auto" (heterodyne is CPU-only)
if self.backend_name == "gpu":
import warnings
warnings.warn(
"backend_name='gpu' is not supported; heterodyne is CPU-only. "
"Falling back to 'auto'.",
DeprecationWarning,
stacklevel=2,
)
self.backend_name = "auto"
logger.debug(
"CMCConfig constructed: enable=%s num_shards=%s backend=%s",
self.enable,
self.num_shards,
self.backend_name,
)
# ==================================================================
# Validation
# ==================================================================
[docs]
def validate(self) -> list[str]:
"""Run comprehensive field validation and return a list of error strings.
Returns
-------
list[str]
Empty list when the configuration is valid; one entry per violation
otherwise. Does *not* raise — callers decide how to handle errors.
"""
errors: list[str] = []
# ---- enable ---------------------------------------------------
enable_str = str(self.enable).lower()
if enable_str not in _VALID_ENABLE:
errors.append(
f"enable={self.enable!r} is not valid; "
f"must be one of {sorted(_VALID_ENABLE)!r} or a bool."
)
# ---- min_points_for_cmc ---------------------------------------
if self.min_points_for_cmc < 1:
errors.append(f"min_points_for_cmc={self.min_points_for_cmc} must be >= 1.")
# ---- per_angle_mode -------------------------------------------
if self.per_angle_mode not in _VALID_PER_ANGLE_MODE:
errors.append(
f"per_angle_mode={self.per_angle_mode!r} is not valid; "
f"must be one of {sorted(_VALID_PER_ANGLE_MODE)!r}."
)
if self.constant_scaling_threshold < 1:
errors.append(
f"constant_scaling_threshold={self.constant_scaling_threshold} "
"must be >= 1."
)
# ---- sharding -------------------------------------------------
if self.sharding_strategy not in _VALID_SHARDING_STRATEGY:
errors.append(
f"sharding_strategy={self.sharding_strategy!r} is not valid; "
f"must be one of {sorted(_VALID_SHARDING_STRATEGY)!r}."
)
if isinstance(self.num_shards, int) and self.num_shards < 1:
errors.append(
f"num_shards={self.num_shards} must be >= 1 when set explicitly."
)
elif isinstance(self.num_shards, str) and self.num_shards != "auto":
errors.append(f"num_shards={self.num_shards!r} must be an int or 'auto'.")
if isinstance(self.max_points_per_shard, int):
if self.max_points_per_shard < 1:
errors.append(
f"max_points_per_shard={self.max_points_per_shard} must be >= 1."
)
elif (
isinstance(self.max_points_per_shard, str)
and self.max_points_per_shard != "auto"
):
errors.append(
f"max_points_per_shard={self.max_points_per_shard!r} must be an int "
"or 'auto'."
)
if self.min_points_per_shard < 1:
errors.append(
f"min_points_per_shard={self.min_points_per_shard} must be >= 1."
)
if self.min_points_per_param < 1:
errors.append(
f"min_points_per_param={self.min_points_per_param} must be >= 1."
)
# Cross-check: if both are explicit integers, min <= max.
if isinstance(self.min_points_per_shard, int) and isinstance(
self.max_points_per_shard, int
):
if self.min_points_per_shard > self.max_points_per_shard:
errors.append(
f"min_points_per_shard={self.min_points_per_shard} exceeds "
f"max_points_per_shard={self.max_points_per_shard}."
)
# ---- backend --------------------------------------------------
if self.backend_name not in _VALID_BACKEND_NAME:
errors.append(
f"backend_name={self.backend_name!r} is not valid; "
f"must be one of {sorted(_VALID_BACKEND_NAME)!r}."
)
if self.chain_method not in _VALID_CHAIN_METHOD:
errors.append(
f"chain_method={self.chain_method!r} is not valid; "
f"must be one of {sorted(_VALID_CHAIN_METHOD)!r}."
)
# ---- sampling -------------------------------------------------
if self.num_warmup < 1:
errors.append(f"num_warmup={self.num_warmup} must be >= 1.")
# Rule 12 — Dense-mass NUTS needs ≥ DENSE_MASS_WARMUP_FLOOR steps.
# Honour fast_warmup as the explicit opt-out for CI / pytest fast mode.
if (
self.dense_mass
and not self.fast_warmup
and 1 <= self.num_warmup < DENSE_MASS_WARMUP_FLOOR
):
errors.append(
f"num_warmup={self.num_warmup} is below the Rule 12 floor "
f"({DENSE_MASS_WARMUP_FLOOR}) required when dense_mass=True. "
"Either raise num_warmup, set dense_mass=False, or set "
"fast_warmup=True for CI fast-mode (not production)."
)
if self.num_samples < 1:
errors.append(f"num_samples={self.num_samples} must be >= 1.")
if self.num_chains < 1:
errors.append(f"num_chains={self.num_chains} must be >= 1.")
if not (0.5 <= self.target_accept_prob <= 0.99):
errors.append(
f"target_accept_prob={self.target_accept_prob} must be in [0.5, 0.99]."
)
if self.max_tree_depth < 1:
errors.append(f"max_tree_depth={self.max_tree_depth} must be >= 1.")
if self.init_strategy not in _VALID_INIT_STRATEGY:
errors.append(
f"init_strategy={self.init_strategy!r} is not valid; "
f"must be one of {sorted(_VALID_INIT_STRATEGY)!r}."
)
if self.min_warmup < 1:
errors.append(f"min_warmup={self.min_warmup} must be >= 1.")
if self.min_samples < 1:
errors.append(f"min_samples={self.min_samples} must be >= 1.")
if self.min_warmup > self.num_warmup:
errors.append(
f"min_warmup={self.min_warmup} must be <= num_warmup={self.num_warmup}."
)
if self.min_samples > self.num_samples:
errors.append(
f"min_samples={self.min_samples} must be <= num_samples={self.num_samples}."
)
# ---- validation thresholds ------------------------------------
if self.max_r_hat <= 1.0:
errors.append(
f"max_r_hat={self.max_r_hat} must be > 1.0 "
"(R-hat is always >= 1 by definition; threshold must exceed 1.0)."
)
if self.min_ess < 1:
errors.append(f"min_ess={self.min_ess} must be >= 1.")
if not (0.0 < self.min_bfmi <= 1.0):
errors.append(f"min_bfmi={self.min_bfmi} must be in (0, 1].")
if not (0.0 <= self.max_divergence_rate <= 1.0):
errors.append(
f"max_divergence_rate={self.max_divergence_rate} must be in [0, 1]."
)
if self.max_parameter_cv <= 0.0:
errors.append(f"max_parameter_cv={self.max_parameter_cv} must be > 0.")
# ---- NLSQ priors ----------------------------------------------
if self.nlsq_prior_width_factor <= 0.0:
errors.append(
f"nlsq_prior_width_factor={self.nlsq_prior_width_factor} must be > 0."
)
# ---- combination ----------------------------------------------
if self.combination_method not in _VALID_COMBINATION_METHOD:
errors.append(
f"combination_method={self.combination_method!r} is not valid; "
f"must be one of {sorted(_VALID_COMBINATION_METHOD)!r}."
)
if not (0.0 < self.min_success_rate <= 1.0):
errors.append(
f"min_success_rate={self.min_success_rate} must be in (0, 1]."
)
if not (0.0 < self.min_success_rate_warning <= 1.0):
errors.append(
f"min_success_rate_warning={self.min_success_rate_warning} "
"must be in (0, 1]."
)
if self.min_success_rate_warning > self.min_success_rate:
logger.warning(
"min_success_rate_warning=%.2f > min_success_rate=%.2f — "
"the warning threshold will never trigger (abort fires first).",
self.min_success_rate_warning,
self.min_success_rate,
)
# ---- timeout --------------------------------------------------
if self.per_shard_timeout < 1:
errors.append(
f"per_shard_timeout={self.per_shard_timeout} must be >= 1 second."
)
if self.heartbeat_timeout < 1:
errors.append(
f"heartbeat_timeout={self.heartbeat_timeout} must be >= 1 second."
)
if self.heartbeat_timeout > self.per_shard_timeout:
errors.append(
f"heartbeat_timeout={self.heartbeat_timeout} must be <= "
f"per_shard_timeout={self.per_shard_timeout}."
)
# ---- reparameterization ---------------------------------------
if not (0.0 < self.bimodal_min_weight < 0.5):
errors.append(
f"bimodal_min_weight={self.bimodal_min_weight} must be in (0, 0.5); "
"it represents the minor mixture component weight."
)
if self.bimodal_min_separation <= 0.0:
errors.append(
f"bimodal_min_separation={self.bimodal_min_separation} must be > 0."
)
# Cache for is_valid() fast path.
self._validation_errors = errors
if errors:
logger.warning(
"CMCConfig validation found %d error(s): %s",
len(errors),
"; ".join(errors),
)
return errors
[docs]
def is_valid(self) -> bool:
"""Return True if the configuration passes all validation checks.
Equivalent to ``len(self.validate()) == 0``.
Returns
-------
bool
"""
return len(self.validate()) == 0
# ==================================================================
# Runtime queries
# ==================================================================
[docs]
def should_enable_cmc(
self,
n_points: int,
analysis_mode: str | None = None,
) -> bool:
"""Decide whether to run CMC given the dataset size.
Parameters
----------
n_points:
Total number of data points in the dataset.
analysis_mode:
Optional homodyne-parity kwarg. Accepted and ignored —
heterodyne does not gate CMC on the analyzer mode. Present
so that callers ported from homodyne continue to work.
Returns
-------
bool
``True`` if CMC should run for this dataset.
"""
del analysis_mode # homodyne-parity kwarg, intentionally unused
enable_str = str(self.enable).lower()
if enable_str in {"always", "true"}:
logger.debug("CMC enabled unconditionally (enable=%r).", self.enable)
return True
if enable_str in {"never", "false"}:
logger.debug("CMC disabled unconditionally (enable=%r).", self.enable)
return False
# "auto" branch — gate on minimum dataset size.
if n_points >= self.min_points_for_cmc:
logger.debug(
"CMC auto-enabled: n_points=%d >= min_points_for_cmc=%d.",
n_points,
self.min_points_for_cmc,
)
return True
logger.info(
"CMC auto-disabled: n_points=%d < min_points_for_cmc=%d.",
n_points,
self.min_points_for_cmc,
)
return False
[docs]
def get_num_shards(
self,
n_points: int,
n_phi: int,
n_params: int = _N_PARAMS_HETERODYNE,
) -> int:
"""Compute the number of shards K for a given dataset.
When ``num_shards`` is an explicit integer it is returned directly
(clamped to >= 1). When ``"auto"``, K is derived as:
1. Start from ``max(n_phi, 2)`` — at least as many shards as phi angles.
2. Apply the ``min_points_per_shard`` lower bound:
``K <= n_points // min_points_per_shard``.
3. Apply the ``min_points_per_param`` constraint:
``K <= n_points // (n_params * min_points_per_param)``.
4. Apply the ``max_points_per_shard`` upper bound when set:
``K >= ceil(n_points / max_points_per_shard)``.
5. Clamp to ``[1, n_points]``.
Parameters
----------
n_points:
Total number of data points.
n_phi:
Number of distinct phi (azimuthal angle) bins.
n_params:
Number of free model parameters (default: 14 for heterodyne).
Returns
-------
int
Number of shards K >= 1.
"""
if isinstance(self.num_shards, int):
k = max(1, self.num_shards)
logger.debug("Using explicit num_shards=%d.", k)
return k
# Auto computation.
if n_points < 1:
logger.warning(
"get_num_shards called with n_points=%d; returning K=1.", n_points
)
return 1
# Lower bound from phi structure: at least one shard per angle group.
k_from_phi = max(n_phi, 2)
# Upper bound from min shard size.
k_max_from_min_size = n_points // max(self.min_points_per_shard, 1)
# Upper bound from points-per-parameter constraint.
min_shard_size_for_params = n_params * self.min_points_per_param
k_max_from_params = n_points // max(min_shard_size_for_params, 1)
k_upper = min(k_max_from_min_size, k_max_from_params)
# Lower bound from max_points_per_shard cap.
if isinstance(self.max_points_per_shard, int):
k_min_from_max_size = math.ceil(
n_points / max(self.max_points_per_shard, 1)
)
else:
# "auto": target shard size = min_shard_size_for_params so NUTS
# O(n) cost stays bounded. Without this floor K collapses to
# k_from_phi (typically 2) and shards become 100K+ points.
k_min_from_max_size = math.ceil(
n_points / max(min_shard_size_for_params, 1)
)
# Combine: start from phi suggestion, respect all bounds.
k = max(k_from_phi, k_min_from_max_size)
k = min(k, max(k_upper, 1))
k = max(k, 1)
logger.debug(
"Auto num_shards: n_points=%d n_phi=%d n_params=%d → K=%d "
"(k_from_phi=%d k_max_size=%d k_max_params=%d k_min_cap=%d).",
n_points,
n_phi,
n_params,
k,
k_from_phi,
k_max_from_min_size,
k_max_from_params,
k_min_from_max_size,
)
return k
[docs]
def get_adaptive_sample_counts(
self,
shard_size: int,
n_params: int = _N_PARAMS_HETERODYNE,
) -> tuple[int, int]:
"""Scale warmup and sample counts for a given shard size.
When ``adaptive_sampling=False`` the configured ``num_warmup`` and
``num_samples`` are returned unchanged.
The scaling law is:
.. code-block:: text
scale = clamp(shard_size / reference_size, 0, 1)
warmup = max(min_warmup, round(num_warmup * scale))
samples = max(min_samples, round(num_samples * scale))
where ``reference_size = _REFERENCE_SHARD_SIZE`` (10 000 points) is the
shard size at which the full configured counts are used. Larger shards
are *not* scaled up beyond the configured maximum; the formula saturates
at ``scale = 1``.
A secondary check ensures a minimum of ``n_params`` samples are drawn
(ESS cannot exceed ``num_samples * num_chains``).
Parameters
----------
shard_size:
Number of data points in this shard.
n_params:
Number of model parameters (default: 14).
Returns
-------
tuple[int, int]
``(warmup, samples)`` after adaptive scaling.
"""
if not self.adaptive_sampling:
return self.num_warmup, self.num_samples
# Homodyne CMC parity (config.py:700-760): scale linearly up to a
# reference shard size, then enforce a *parameter-aware* floor that
# is itself capped at the user-configured maximum so adaptive scaling
# only ever reduces work, never inflates it.
scale = min(1.0, max(0.0, shard_size / _REFERENCE_SHARD_SIZE))
scaled_warmup = int(self.num_warmup * scale)
scaled_samples = int(self.num_samples * scale)
# Parameter-aware floors capped at the configured ceiling.
min_warmup_for_params = min(
max(self.min_warmup, 20 * n_params), self.num_warmup
)
min_samples_for_params = min(
max(self.min_samples, 50 * n_params), self.num_samples
)
warmup = max(min_warmup_for_params, scaled_warmup)
samples = max(min_samples_for_params, scaled_samples)
# Rule 12: respect the dense-mass warmup floor regardless of shard size.
warmup = effective_warmup_floor(
warmup, dense_mass=self.dense_mass, fast_warmup=self.fast_warmup
)
if warmup != self.num_warmup or samples != self.num_samples:
logger.debug(
"Adaptive sampling: shard_size=%d (scale=%.3f, n_params=%d) → "
"warmup=%d (was %d), samples=%d (was %d).",
shard_size,
scale,
n_params,
warmup,
self.num_warmup,
samples,
self.num_samples,
)
return warmup, samples
[docs]
def get_effective_per_angle_mode(
self,
n_phi: int,
nlsq_per_angle_mode: str | None = None,
has_nlsq_warmstart: bool = False,
) -> str:
"""Resolve the effective per-angle mode for a concrete dataset.
Mirrors ``homodyne/optimization/cmc/config.py::get_effective_per_angle_mode``
so CMC/NLSQ parameterization stays in lock-step across both packages.
Resolution logic (priority order):
1. If ``nlsq_per_angle_mode`` is provided, mirror it for CMC↔NLSQ
parameterization parity, regardless of ``self.per_angle_mode``.
If both sides are ``"auto"`` AND ``has_nlsq_warmstart`` is True,
promote to ``"constant_averaged"`` so scaling is fixed (fewer
sampled params, less heterogeneity across shards).
2. Else if ``self.per_angle_mode != "auto"`` → return it directly.
3. Else (auto, no NLSQ): ``n_phi >= constant_scaling_threshold``
→ ``"auto"`` (sampled averaged); otherwise → ``"individual"``.
Parameters
----------
n_phi:
Number of distinct phi (azimuthal angle) bins in the dataset.
nlsq_per_angle_mode:
The per-angle mode resolved by the preceding NLSQ fit, if any.
When provided this overrides the configured mode for parity.
has_nlsq_warmstart:
Whether a valid NLSQ warm-start is available for this run.
Returns
-------
str
Effective mode: ``"auto"``, ``"constant"``, ``"constant_averaged"``,
or ``"individual"``.
"""
# Priority 1: NLSQ override (regardless of self.per_angle_mode).
if nlsq_per_angle_mode is not None:
if (
has_nlsq_warmstart
and nlsq_per_angle_mode == "auto"
and self.per_angle_mode == "auto"
):
logger.info(
"CMC per-angle mode: auto -> constant_averaged "
"(NLSQ warm-start present, fixing scaling for stability)"
)
return "constant_averaged"
logger.debug(
"Per-angle mode mirrored from NLSQ warm-start: %r.",
nlsq_per_angle_mode,
)
return nlsq_per_angle_mode
# Priority 2: explicit non-auto configured mode.
if self.per_angle_mode != "auto":
logger.debug("Per-angle mode fixed to %r (not auto).", self.per_angle_mode)
return self.per_angle_mode
# Priority 3: auto resolution from n_phi.
if n_phi >= self.constant_scaling_threshold:
resolved = "auto"
else:
resolved = "individual"
logger.debug(
"Per-angle mode auto-resolved to %r (n_phi=%d threshold=%d).",
resolved,
n_phi,
self.constant_scaling_threshold,
)
return resolved
# ==================================================================
# Serialization
# ==================================================================
[docs]
def to_dict(self) -> dict[str, Any]:
"""Serialise the configuration to a nested dictionary.
The returned structure uses the same section names expected by
``from_dict``, making round-trips lossless.
Returns
-------
dict[str, Any]
Nested dictionary representation of the config.
"""
return {
"enable": self.enable,
"min_points_for_cmc": self.min_points_for_cmc,
"run_id": self.run_id,
"per_angle": {
"per_angle_mode": self.per_angle_mode,
"constant_scaling_threshold": self.constant_scaling_threshold,
},
"sharding": {
"sharding_strategy": self.sharding_strategy,
"num_shards": self.num_shards,
"max_points_per_shard": self.max_points_per_shard,
"min_points_per_shard": self.min_points_per_shard,
"min_points_per_param": self.min_points_per_param,
},
"backend_config": {
"backend_name": self.backend_name,
"enable_checkpoints": self.enable_checkpoints,
"checkpoint_dir": self.checkpoint_dir,
"chain_method": self.chain_method,
},
"per_shard_mcmc": {
"num_warmup": self.num_warmup,
"num_samples": self.num_samples,
"num_chains": self.num_chains,
"target_accept_prob": self.target_accept_prob,
"max_tree_depth": self.max_tree_depth,
"seed": self.seed,
"dense_mass": self.dense_mass,
"init_strategy": self.init_strategy,
"adaptive_sampling": self.adaptive_sampling,
"min_warmup": self.min_warmup,
"min_samples": self.min_samples,
"fast_warmup": self.fast_warmup,
},
"validation": {
"max_r_hat": self.max_r_hat,
"min_ess": self.min_ess,
"min_bfmi": self.min_bfmi,
"max_divergence_rate": self.max_divergence_rate,
"require_nlsq_warmstart": self.require_nlsq_warmstart,
"max_parameter_cv": self.max_parameter_cv,
"heterogeneity_abort": self.heterogeneity_abort,
"allow_degenerate_warmstart": self.allow_degenerate_warmstart,
},
"nlsq": {
"use_nlsq_warmstart": self.use_nlsq_warmstart,
"use_nlsq_informed_priors": self.use_nlsq_informed_priors,
"nlsq_prior_width_factor": self.nlsq_prior_width_factor,
"use_log_space_priors": self.use_log_space_priors,
},
"profiling": {
"enable_jax_profiling": self.enable_jax_profiling,
"jax_profile_dir": self.jax_profile_dir,
},
"prior_tempering": self.prior_tempering,
"combination": {
"combination_method": self.combination_method,
"min_success_rate": self.min_success_rate,
"min_success_rate_warning": self.min_success_rate_warning,
},
"timeout": {
"per_shard_timeout": self.per_shard_timeout,
"heartbeat_timeout": self.heartbeat_timeout,
},
"reparameterization": {
"use_reparam": self.use_reparam,
"reparameterization_d_total": self.reparameterization_d_total,
"reparameterization_log_gamma": self.reparameterization_log_gamma,
"bimodal_min_weight": self.bimodal_min_weight,
"bimodal_min_separation": self.bimodal_min_separation,
},
}
[docs]
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> CMCConfig:
"""Construct a CMCConfig from a (possibly nested) dictionary.
Recognised top-level keys and sections:
- ``enable``, ``min_points_for_cmc``, ``run_id`` — top-level scalars.
- ``prior_tempering`` — top-level scalar.
- ``per_angle`` — maps to ``per_angle_mode``, ``constant_scaling_threshold``.
- ``sharding`` — maps to the five sharding fields.
- ``backend_config`` — maps to the four backend fields.
- ``per_shard_mcmc`` — maps to the eleven sampling fields.
- ``validation`` — maps to the seven validation-threshold fields.
- ``nlsq`` — maps to the three NLSQ-prior fields.
- ``combination`` — maps to the three combination fields.
- ``timeout`` — maps to the two timeout fields.
- ``reparameterization`` — maps to the five reparam fields.
Flat (non-nested) dictionaries are also accepted for backward
compatibility: any key that matches a field name directly is used as-is.
Unrecognised top-level keys emit a ``warnings.warn`` so that
configuration typos surface immediately.
Parameters
----------
config_dict:
Parsed YAML / JSON dictionary.
Returns
-------
CMCConfig
Fully constructed configuration instance.
"""
kwargs: dict[str, Any] = {}
# --- Helpers ---------------------------------------------------
def _extract_section(section_key: str) -> dict[str, Any]:
val = config_dict.get(section_key, {})
if not isinstance(val, dict):
warnings.warn(
f"CMCConfig.from_dict: section {section_key!r} is not a dict "
f"(got {type(val).__name__!r}); ignoring.",
stacklevel=3,
)
return {}
return val
def _pick(
target_field: str,
source: dict[str, Any],
source_key: str | None = None,
) -> None:
key = source_key if source_key is not None else target_field
if key in source:
kwargs[target_field] = source[key]
# --- Top-level scalars -----------------------------------------
_pick("enable", config_dict)
_pick("min_points_for_cmc", config_dict)
_pick("run_id", config_dict)
_pick("prior_tempering", config_dict)
# --- per_angle section -----------------------------------------
per_angle = _extract_section("per_angle")
_pick("per_angle_mode", per_angle)
_pick("constant_scaling_threshold", per_angle)
# Flat fallback
_pick("per_angle_mode", config_dict)
_pick("constant_scaling_threshold", config_dict)
# --- sharding section ------------------------------------------
sharding = _extract_section("sharding")
_pick("sharding_strategy", sharding)
_pick("num_shards", sharding)
_pick("max_points_per_shard", sharding)
_pick("min_points_per_shard", sharding)
_pick("min_points_per_param", sharding)
# Flat fallback
for _f in (
"sharding_strategy",
"num_shards",
"max_points_per_shard",
"min_points_per_shard",
"min_points_per_param",
):
_pick(_f, config_dict)
# --- backend_config section ------------------------------------
backend = _extract_section("backend_config")
_pick("backend_name", backend)
_pick("enable_checkpoints", backend)
_pick("checkpoint_dir", backend)
_pick("chain_method", backend)
# Flat fallback
for _f in (
"backend_name",
"enable_checkpoints",
"checkpoint_dir",
"chain_method",
):
_pick(_f, config_dict)
# Coerce removed "gpu" backend to "auto" (heterodyne is CPU-only)
if kwargs.get("backend_name") == "gpu":
warnings.warn(
"backend_name='gpu' is not supported; heterodyne is CPU-only. "
"Falling back to 'auto'.",
DeprecationWarning,
stacklevel=2,
)
kwargs["backend_name"] = "auto"
# --- per_shard_mcmc section ------------------------------------
mcmc = _extract_section("per_shard_mcmc")
_pick("num_warmup", mcmc)
_pick("num_samples", mcmc)
_pick("num_chains", mcmc)
_pick("target_accept_prob", mcmc)
# Accept legacy key name (homodyne parity rename).
if "target_accept" in mcmc and "target_accept_prob" not in kwargs:
kwargs["target_accept_prob"] = mcmc["target_accept"]
_pick("max_tree_depth", mcmc)
_pick("seed", mcmc)
_pick("dense_mass", mcmc)
_pick("init_strategy", mcmc)
_pick("adaptive_sampling", mcmc)
_pick("min_warmup", mcmc)
_pick("min_samples", mcmc)
_pick("fast_warmup", mcmc)
# Flat fallbacks (including legacy target_accept)
for _f in (
"num_warmup",
"num_samples",
"num_chains",
"target_accept_prob",
"max_tree_depth",
"seed",
"dense_mass",
"init_strategy",
"adaptive_sampling",
"min_warmup",
"min_samples",
"fast_warmup",
):
_pick(_f, config_dict)
if "target_accept" in config_dict and "target_accept_prob" not in kwargs:
kwargs["target_accept_prob"] = config_dict["target_accept"]
# --- validation section ----------------------------------------
validation = _extract_section("validation")
_pick("max_r_hat", validation)
# Accept legacy key name (homodyne parity rename).
if "r_hat_threshold" in validation and "max_r_hat" not in kwargs:
kwargs["max_r_hat"] = validation["r_hat_threshold"]
_pick("min_ess", validation)
_pick("min_bfmi", validation)
_pick("max_divergence_rate", validation)
_pick("require_nlsq_warmstart", validation)
_pick("max_parameter_cv", validation)
_pick("heterogeneity_abort", validation)
_pick("allow_degenerate_warmstart", validation)
# Flat fallbacks
for _f in (
"max_r_hat",
"min_ess",
"min_bfmi",
"max_divergence_rate",
"require_nlsq_warmstart",
"max_parameter_cv",
"heterogeneity_abort",
"allow_degenerate_warmstart",
):
_pick(_f, config_dict)
if "r_hat_threshold" in config_dict and "max_r_hat" not in kwargs:
kwargs["max_r_hat"] = config_dict["r_hat_threshold"]
# --- nlsq section ----------------------------------------------
nlsq = _extract_section("nlsq")
_pick("use_nlsq_warmstart", nlsq)
_pick("use_nlsq_informed_priors", nlsq)
_pick("nlsq_prior_width_factor", nlsq)
_pick("use_log_space_priors", nlsq)
# Accept legacy key name (homodyne parity rename).
if "prior_width_factor" in nlsq and "nlsq_prior_width_factor" not in kwargs:
kwargs["nlsq_prior_width_factor"] = nlsq["prior_width_factor"]
# Flat fallbacks
for _f in (
"use_nlsq_warmstart",
"use_nlsq_informed_priors",
"nlsq_prior_width_factor",
"use_log_space_priors",
):
_pick(_f, config_dict)
if (
"prior_width_factor" in config_dict
and "nlsq_prior_width_factor" not in kwargs
):
kwargs["nlsq_prior_width_factor"] = config_dict["prior_width_factor"]
# --- combination section ---------------------------------------
combination = _extract_section("combination")
_pick("combination_method", combination)
_pick("min_success_rate", combination)
_pick("min_success_rate_warning", combination)
# Flat fallbacks
for _f in (
"combination_method",
"min_success_rate",
"min_success_rate_warning",
):
_pick(_f, config_dict)
# --- timeout section -------------------------------------------
timeout = _extract_section("timeout")
_pick("per_shard_timeout", timeout)
_pick("heartbeat_timeout", timeout)
# Flat fallbacks
for _f in ("per_shard_timeout", "heartbeat_timeout"):
_pick(_f, config_dict)
# --- reparameterization section --------------------------------
reparam = _extract_section("reparameterization")
_pick("use_reparam", reparam)
_pick("reparameterization_d_total", reparam)
_pick("reparameterization_log_gamma", reparam)
_pick("bimodal_min_weight", reparam)
_pick("bimodal_min_separation", reparam)
# Flat fallbacks
for _f in (
"use_reparam",
"reparameterization_d_total",
"reparameterization_log_gamma",
"bimodal_min_weight",
"bimodal_min_separation",
):
_pick(_f, config_dict)
# --- profiling section -----------------------------------------
profiling = _extract_section("profiling")
_pick("enable_jax_profiling", profiling)
_pick("jax_profile_dir", profiling)
# Flat fallbacks
for _f in ("enable_jax_profiling", "jax_profile_dir"):
_pick(_f, config_dict)
# --- Warn on unrecognised top-level keys -----------------------
_known_top_level: frozenset[str] = frozenset(
{
"enable",
"min_points_for_cmc",
"run_id",
"prior_tempering",
"per_angle",
"sharding",
"backend_config",
"per_shard_mcmc",
"validation",
"nlsq",
"combination",
"timeout",
"reparameterization",
"profiling",
# Legacy flat keys accepted above.
"per_angle_mode",
"constant_scaling_threshold",
"sharding_strategy",
"num_shards",
"max_points_per_shard",
"min_points_per_shard",
"min_points_per_param",
"backend_name",
"enable_checkpoints",
"checkpoint_dir",
"chain_method",
"num_warmup",
"num_samples",
"num_chains",
"target_accept_prob",
"target_accept",
"max_tree_depth",
"seed",
"dense_mass",
"init_strategy",
"adaptive_sampling",
"min_warmup",
"min_samples",
"fast_warmup",
"max_r_hat",
"r_hat_threshold",
"min_ess",
"min_bfmi",
"max_divergence_rate",
"require_nlsq_warmstart",
"max_parameter_cv",
"heterogeneity_abort",
"allow_degenerate_warmstart",
"use_nlsq_warmstart",
"use_nlsq_informed_priors",
"nlsq_prior_width_factor",
"prior_width_factor",
"use_log_space_priors",
"combination_method",
"min_success_rate",
"min_success_rate_warning",
"per_shard_timeout",
"heartbeat_timeout",
"use_reparam",
"reparameterization_d_total",
"reparameterization_log_gamma",
"bimodal_min_weight",
"bimodal_min_separation",
"enable_jax_profiling",
"jax_profile_dir",
}
)
unknown = sorted(set(config_dict.keys()) - _known_top_level)
if unknown:
warnings.warn(
f"CMCConfig.from_dict: unrecognised key(s) {unknown!r} will be ignored.",
stacklevel=2,
)
logger.debug(
"CMCConfig.from_dict: constructed with %d kwargs from %d input keys.",
len(kwargs),
len(config_dict),
)
return cls(**kwargs)