"""NLSQ-informed prior construction for heterodyne CMC analysis.
Builds NumPyro distribution dictionaries from NLSQ warm-start results
or from the parameter registry defaults, including log-space priors
for parameters flagged with ``log_space=True``.
"""
from __future__ import annotations
import dataclasses
import math
from typing import TYPE_CHECKING, Any
import numpy as np
import numpyro.distributions as dist
from numpyro.distributions.truncated import TwoSidedTruncatedDistribution
from heterodyne.config.parameter_registry import DEFAULT_REGISTRY, ParameterRegistry
from heterodyne.utils.logging import get_logger
if TYPE_CHECKING:
from heterodyne.config.parameter_space import ParameterSpace
from heterodyne.optimization.nlsq.results import NLSQResult
logger = get_logger(__name__)
def _check_registry_spec_sync(registry: ParameterRegistry) -> None:
    """Cross-check *registry* against ``_DEFAULT_PRIOR_SPECS``; raise on drift.

    This is the single implementation of the dual-prior sync check. It is
    shared by :func:`_verify_dual_prior_sync` (run as a side effect of
    ``build_default_priors``) and by
    :meth:`PriorBuilder._verify_registry_spec_sync` (the explicit-builder
    path). Keeping it at module level avoids the previous recursion trap in
    which ``_verify_dual_prior_sync`` constructed a ``PriorBuilder``, which
    called ``build_default_priors``, which called
    ``_verify_dual_prior_sync`` again.

    Two directions are checked:

    1. Every entry of ``_DEFAULT_PRIOR_SPECS`` must exist in *registry*
       with non-``None`` ``prior_mean``/``prior_std`` values that match the
       spec ``(loc, scale)`` under :func:`math.isclose`, using
       ``_DEFAULT_SYNC_REL_TOL`` / ``_DEFAULT_SYNC_ABS_TOL``.
    2. Every registry entry absent from the spec table must carry no prior
       (both ``prior_mean`` and ``prior_std`` are ``None``).

    NOTE(review): the error message quotes ``rel_tol=1e-6 / abs_tol=1e-9``
    literally; keep it in step with the tolerance constants, which are
    defined outside this function.

    Args:
        registry: Registry whose prior metadata is compared to the spec
            table.

    Raises:
        RuntimeError: Listing every mismatch found, if there is at least
            one.
    """
    # Imported lazily to avoid an import cycle at module load time:
    # parameter_space imports parameter_registry.
    from heterodyne.config.parameter_space import _DEFAULT_PRIOR_SPECS

    def _matches(actual: float, expected: float) -> bool:
        return math.isclose(
            actual,
            expected,
            rel_tol=_DEFAULT_SYNC_REL_TOL,
            abs_tol=_DEFAULT_SYNC_ABS_TOL,
        )

    problems: list[str] = []

    # Direction 1: every spec entry must be mirrored by the registry.
    for pname, (loc, scale) in _DEFAULT_PRIOR_SPECS.items():
        try:
            entry = registry[pname]
        except KeyError:
            problems.append(
                f"{pname!r}: present in _DEFAULT_PRIOR_SPECS "
                f"({loc}, {scale}) but absent from the registry"
            )
            continue
        mean, std = entry.prior_mean, entry.prior_std
        if mean is None or std is None:
            problems.append(
                f"{pname!r}: registry has prior_mean={mean!r}, "
                f"prior_std={std!r}; spec wants "
                f"({loc}, {scale})"
            )
            continue
        # Both comparisons are evaluated, as before (no short-circuit).
        agreement = (_matches(mean, loc), _matches(std, scale))
        if not all(agreement):
            problems.append(
                f"{pname!r}: registry=({mean}, {std}), "
                f"spec=({loc}, {scale})"
            )

    # Direction 2: registry entries with a prior but no spec row.
    for pname in registry:
        if pname in _DEFAULT_PRIOR_SPECS:
            continue
        entry = registry[pname]
        if entry.prior_mean is None and entry.prior_std is None:
            continue
        problems.append(
            f"{pname!r}: registry has prior "
            f"(prior_mean={entry.prior_mean}, prior_std={entry.prior_std}) "
            "but is missing from _DEFAULT_PRIOR_SPECS"
        )

    if problems:
        bullet = "\n - "
        raise RuntimeError(
            "Dual-prior sync violation (CLAUDE.md Rule 9). "
            "parameter_registry.py and parameter_space._DEFAULT_PRIOR_SPECS "
            "must agree on (prior_mean, prior_std) vs (loc, scale) per "
            "parameter, within rel_tol=1e-6 / abs_tol=1e-9. Mismatches:"
            + bullet
            + bullet.join(problems)
        )

    logger.debug(
        "Registry/spec sync verified (%d parameters)",
        len(_DEFAULT_PRIOR_SPECS),
    )
def _verify_dual_prior_sync(registry: ParameterRegistry | None) -> None:
    """Enforce the dual-prior sync rule (CLAUDE.md Rule 9) at build time.

    Called from :func:`build_default_priors` so that default-prior
    construction validates the registry against ``_DEFAULT_PRIOR_SPECS``.
    It calls the module-level checker directly instead of instantiating
    :class:`PriorBuilder`. The builder route would recurse without bound
    (``PriorBuilder.build`` → ``build_default_priors`` → here).

    Args:
        registry: Registry to check. ``None`` selects
            :data:`DEFAULT_REGISTRY`.

    Raises:
        RuntimeError: Propagated from :func:`_check_registry_spec_sync`
            when the two prior tables disagree.
    """
    target = DEFAULT_REGISTRY if registry is None else registry
    _check_registry_spec_sync(target)
[docs]
def build_default_priors(
    param_space: ParameterSpace,
    registry: ParameterRegistry | None = None,
    use_log_space_priors: bool = True,
) -> dict[str, dist.Distribution]:
    """Build default priors for every varying parameter from registry metadata.

    For each name in ``param_space.varying_names``:

    * If the registry entry has a ``prior_mean`` and a positive
      ``prior_std``, the prior is ``TruncatedNormal(prior_mean, prior_std)``
      truncated to ``param_space.bounds[name]``. This is the form that
      :func:`temper_priors` can widen by ``sqrt(K)`` for Consensus Monte
      Carlo sharding.
    * Otherwise the prior is ``Uniform`` over the same bounds.

    Codex S1: with ``use_log_space_priors=True`` (the default), parameters
    flagged ``log_space=True`` in the registry (e.g. D0_ref, D0_sample, v0)
    are then replaced by the LogNormal priors from
    :func:`build_log_space_priors`. LogNormal gives much better mass-matrix
    conditioning for prefactors spanning several orders of magnitude. Pass
    ``use_log_space_priors=False`` to keep TruncatedNormal/Uniform
    throughout.

    The reparameterized path (``CMCConfig.use_reparam=True``) samples
    ``log_X_at_tref`` directly and does not call this function for those
    parameters.

    Gemini G2: when *registry* is omitted, the dual-prior sync gate
    (CLAUDE.md Rule 9) runs against :data:`DEFAULT_REGISTRY` first.
    Callers passing an explicit registry (e.g. test stubs) skip the gate
    and are responsible for keeping it consistent.

    Args:
        param_space: Parameter space defining which parameters vary and
            their physical bounds.
        registry: Parameter registry to read metadata from. Defaults to
            :data:`DEFAULT_REGISTRY`.
        use_log_space_priors: Apply log-space priors to ``log_space=True``
            parameters from the registry. Default ``True``.

    Returns:
        Dictionary mapping parameter names to NumPyro distributions, for
        the names in ``param_space.varying_names`` only.

    Raises:
        RuntimeError: From the sync gate, when *registry* is omitted and
            the registry disagrees with ``_DEFAULT_PRIOR_SPECS``.
    """
    if registry is None:
        # Sync gate only on the default path (see docstring).
        registry = DEFAULT_REGISTRY
        _verify_dual_prior_sync(None)

    priors: dict[str, dist.Distribution] = {}
    for pname in param_space.varying_names:
        entry = registry[pname]
        lower, upper = param_space.bounds[pname]
        has_informative_prior = (
            entry.prior_mean is not None
            and entry.prior_std is not None
            and entry.prior_std > 0
        )
        if not has_informative_prior:
            # No usable registry prior: flat over the physical bounds.
            priors[pname] = dist.Uniform(low=lower, high=upper)
            logger.debug(
                "Default prior for %s: Uniform(%.4e, %.4e)",
                pname,
                lower,
                upper,
            )
            continue

        priors[pname] = dist.TruncatedNormal(
            loc=entry.prior_mean,
            scale=entry.prior_std,
            low=lower,
            high=upper,
        )
        logger.debug(
            "Default prior for %s: TruncatedNormal(loc=%.4e, scale=%.4e)",
            pname,
            entry.prior_mean,
            entry.prior_std,
        )

    if use_log_space_priors:
        # build_log_space_priors returns only log_space=True names, so the
        # update overwrites just those entries and keeps the rest.
        priors.update(
            build_log_space_priors(
                list(param_space.varying_names), registry=registry
            )
        )

    logger.info(
        "Built %d default priors from registry (use_log_space_priors=%s)",
        len(priors),
        use_log_space_priors,
    )
    return priors
[docs]
def build_log_space_priors(
    param_names: list[str],
    registry: ParameterRegistry | None = None,
) -> dict[str, dist.Distribution]:
    """Build LogNormal priors for parameters flagged ``log_space=True``.

    For each flagged parameter a ``LogNormal(mu, sigma)`` is built as
    follows.

    * **Median.** ``exp(mu)`` is the registry ``prior_mean`` when that is
      positive, otherwise the registry ``default`` when that is positive.
      If neither is positive the parameter is skipped with a warning,
      because a LogNormal cannot be centered on a non-positive value.
    * **Spread.** When ``prior_std`` is positive,
      ``sigma = sqrt(log(1 + CV**2))`` with ``CV = prior_std / median``.
      A LogNormal's coefficient of variation is
      ``sqrt(exp(sigma**2) - 1)`` regardless of ``mu``, so the prior has
      exactly this CV. Its mean is ``median * sqrt(1 + CV**2)``, i.e. above
      the registry value. Without a usable ``prior_std``, ``sigma = 1.0``
      (CV ≈ 1.31).
    * **Floor.** ``sigma`` is floored at ``0.01`` to avoid a near-degenerate
      prior.

    Parameters not flagged ``log_space`` are skipped silently.

    Args:
        param_names: Parameter names to consider.
        registry: Parameter registry. Defaults to :data:`DEFAULT_REGISTRY`.

    Returns:
        Dictionary mapping parameter names to LogNormal distributions, for
        flagged parameters with a positive center only.
    """
    if registry is None:
        registry = DEFAULT_REGISTRY
    priors: dict[str, dist.Distribution] = {}
    for name in param_names:
        info = registry[name]
        if not info.log_space:
            continue

        # Choose the median (exp(mu)) in original space.
        if info.prior_mean is not None and info.prior_mean > 0:
            center = info.prior_mean
        elif info.default > 0:
            center = info.default
        else:
            # Cannot construct LogNormal for non-positive center
            logger.warning(
                "Skipping log-space prior for %s: "
                "prior_mean=%s, default=%s (both non-positive)",
                name,
                info.prior_mean,
                info.default,
            )
            continue

        # LogNormal parameterization: median = exp(mu), so mu = log(center).
        mu = math.log(center)

        # Match the coefficient of variation CV = prior_std / center:
        # for a LogNormal, CV**2 = exp(sigma**2) - 1, so
        # sigma = sqrt(log1p(CV**2)). center > 0 is guaranteed above.
        if info.prior_std is not None and info.prior_std > 0:
            cv = info.prior_std / center
            sigma = math.sqrt(math.log1p(cv**2))
        else:
            # No usable prior_std: sigma = 1.0 corresponds to
            # CV = sqrt(e - 1) ≈ 1.31 (broad, roughly one e-fold either way).
            sigma = 1.0

        # Floor sigma (CV ≈ 0.01) to avoid a degenerate distribution.
        sigma = max(sigma, 0.01)
        priors[name] = dist.LogNormal(loc=mu, scale=sigma)
        logger.debug(
            "Log-space prior for %s: LogNormal(loc=%.4f, scale=%.4f) [median=%.4e]",
            name,
            mu,
            sigma,
            center,
        )

    logger.info(
        "Built %d log-space priors from %d candidates",
        len(priors),
        len(param_names),
    )
    return priors
# ---------------------------------------------------------------------------
# Consensus MC prior tempering
# ---------------------------------------------------------------------------
[docs]
def temper_priors(
    priors: dict[str, dist.Distribution],
    num_shards: int,
) -> dict[str, dist.Distribution]:
    """Widen prior distributions for Consensus MC shard sub-posteriors.

    Each of the ``K = num_shards`` shards sees about ``1/K`` of the data, so
    each shard gets the prior raised to the power ``1/K``. For a (truncated)
    Normal, or a LogNormal (Normal in log space), that is the same family
    with its scale multiplied by ``sqrt(K)``. The K shard priors then
    multiply back to the original prior in the consensus step.

    Tempering rules by distribution type:

    * ``TruncatedNormal`` — a two-sided truncation whose base distribution
      is ``Normal``. Rebuilt with scale × ``sqrt(K)``, keeping loc and the
      truncation bounds.
    * ``LogNormal`` — rebuilt with scale × ``sqrt(K)``, keeping loc.
    * ``Uniform`` — returned unchanged (uninformative; nothing to temper).
    * Anything else — returned unchanged with a warning. This includes
      two-sided truncations of non-Normal families (they are no longer
      silently turned into a TruncatedNormal) and other
      ``TransformedDistribution`` priors such as BetaScaled.

    Args:
        priors: Dict of NumPyro distributions, one per varying parameter.
        num_shards: Number of CMC shards (K). Must be >= 1.

    Returns:
        A new dict with tempered distributions. *priors* itself is not
        mutated. For ``num_shards == 1`` this is a shallow copy.

    Raises:
        ValueError: If ``num_shards < 1``.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")
    if num_shards == 1:
        # No tempering needed for a single shard
        logger.debug("temper_priors: num_shards=1, returning priors unchanged.")
        return dict(priors)

    factor = math.sqrt(num_shards)
    tempered: dict[str, dist.Distribution] = {}
    for name, prior in priors.items():
        if isinstance(prior, TwoSidedTruncatedDistribution) and isinstance(
            prior.base_dist, dist.Normal
        ):
            # Rebuilding as TruncatedNormal is only valid for a Normal base;
            # other truncated families fall through to the generic branch.
            loc = float(prior.base_dist.loc)
            old_scale = float(prior.base_dist.scale)
            scale = old_scale * factor
            low = float(prior.low)
            high = float(prior.high)
            tempered[name] = dist.TruncatedNormal(
                loc=loc, scale=scale, low=low, high=high
            )
            logger.debug(
                "temper_priors: %s TruncatedNormal scale %.4e -> %.4e (x%.2f)",
                name,
                old_scale,
                scale,
                factor,
            )
        elif isinstance(prior, dist.LogNormal):
            # Must come before the generic TransformedDistribution branch:
            # NumPyro implements LogNormal as a TransformedDistribution.
            loc = float(prior.loc)
            scale = float(prior.scale) * factor
            tempered[name] = dist.LogNormal(loc=loc, scale=scale)
            logger.debug(
                "temper_priors: %s LogNormal scale %.4e -> %.4e (x%.2f)",
                name,
                float(prior.scale),
                scale,
                factor,
            )
        elif isinstance(prior, dist.Uniform):
            # Uninformative; keep unchanged
            tempered[name] = prior
            logger.debug("temper_priors: %s Uniform — left unchanged.", name)
        elif isinstance(prior, dist.TransformedDistribution):
            # BetaScaled: cannot easily temper — leave unchanged with warning
            logger.warning(
                "temper_priors: %s has TransformedDistribution (e.g. BetaScaled) — "
                "left unchanged. Consider using TruncatedNormal for tempered CMC.",
                name,
            )
            tempered[name] = prior
        else:
            # Unsupported type (incl. truncated non-Normal); keep and warn
            logger.warning(
                "temper_priors: %s has unsupported type %s — left unchanged. "
                "Consider using TruncatedNormal or LogNormal for proper tempering.",
                name,
                type(prior).__name__,
            )
            tempered[name] = prior

    logger.info(
        "Tempered %d priors for %d shards (scale factor=%.4f).",
        len(tempered),
        num_shards,
        factor,
    )
    return tempered
# ---------------------------------------------------------------------------
# Prior validation
# ---------------------------------------------------------------------------
[docs]
def validate_priors(
    priors: dict[str, dist.Distribution],
    param_space: ParameterSpace,
) -> list[str]:
    """Validate prior distributions against the parameter space.

    Checks, per varying parameter:

    1. A prior exists.
    2. The prior's support intersects the parameter bounds:
       ``TruncatedNormal`` and ``Uniform`` by their explicit
       ``[low, high]``, and ``LogNormal`` (support ``(0, +inf)``) by
       requiring a positive upper bound. Other types (e.g. BetaScaled
       ``TransformedDistribution``) are trusted, since their support is
       fixed at construction.
    3. The prior is not degenerate: a ``TruncatedNormal`` base scale or a
       ``LogNormal`` scale below ``1e-12`` is reported. Uniform and other
       types are not checked.

    Priors for names that are not varying are also reported
    (informational: the sampler ignores them).

    Args:
        priors: Dict of NumPyro distributions.
        param_space: Defines varying parameter names and their bounds.

    Returns:
        List of warning/error strings. An empty list means all checks
        passed.
    """
    issues: list[str] = []
    for name in param_space.varying_names:
        # Check 1: prior exists
        if name not in priors:
            issues.append(f"Missing prior for varying parameter '{name}'.")
            continue
        prior = priors[name]
        low_bound, high_bound = param_space.bounds[name]

        # Check 2: support overlap with bounds
        if isinstance(prior, TwoSidedTruncatedDistribution):
            prior_low = float(prior.low)
            prior_high = float(prior.high)
            if prior_high <= low_bound or prior_low >= high_bound:
                issues.append(
                    f"Prior for '{name}' support [{prior_low:.4e}, {prior_high:.4e}] "
                    f"does not overlap with bounds [{low_bound:.4e}, {high_bound:.4e}]."
                )
        elif isinstance(prior, dist.Uniform):
            prior_low = float(prior.low)
            prior_high = float(prior.high)
            if prior_high <= low_bound or prior_low >= high_bound:
                issues.append(
                    f"Prior for '{name}' Uniform[{prior_low:.4e}, {prior_high:.4e}] "
                    f"does not overlap with bounds [{low_bound:.4e}, {high_bound:.4e}]."
                )
        elif isinstance(prior, dist.LogNormal):
            # Support is (0, +inf): it meets [low_bound, high_bound] only if
            # the *upper* bound is positive.
            if high_bound <= 0:
                issues.append(
                    f"Prior for '{name}' is LogNormal (support > 0) but "
                    f"upper bound is {high_bound:.4e} <= 0."
                )
        # Any other type (e.g. BetaScaled TransformedDistribution) is
        # trusted: its support is set at construction.

        # Check 3: degenerate prior (near-zero scale)
        scale: float | None = None
        if isinstance(prior, TwoSidedTruncatedDistribution):
            # getattr: a truncated non-Normal base may not expose `scale`.
            base_scale = getattr(prior.base_dist, "scale", None)
            if base_scale is not None:
                scale = float(base_scale)
        elif isinstance(prior, dist.LogNormal):
            scale = float(prior.scale)
        if scale is not None and scale < 1e-12:
            issues.append(
                f"Prior for '{name}' is degenerate: scale={scale:.2e} < 1e-12."
            )

    # Report any priors defined for non-varying parameters (informational)
    varying_set = set(param_space.varying_names)
    for name in priors:
        if name not in varying_set:
            issues.append(
                f"Prior defined for '{name}' but it is not a varying parameter. "
                "This prior will be ignored by the sampler."
            )

    if issues:
        logger.warning(
            "validate_priors: %d issue(s) found:\n %s",
            len(issues),
            "\n ".join(issues),
        )
    else:
        logger.info("validate_priors: all %d priors passed validation.", len(priors))
    return issues
# ---------------------------------------------------------------------------
# Prior summary
# ---------------------------------------------------------------------------
[docs]
def summarize_priors(priors: dict[str, dist.Distribution]) -> str:
    """Format a human-readable summary of prior distributions.

    For each prior, reports the distribution type and, where applicable,
    its parameters, mean, standard deviation and support interval.
    LogNormal moments that overflow a float (very wide, e.g. heavily
    tempered, priors) are reported as ``inf`` instead of raising.

    Args:
        priors: Dict of NumPyro distributions.

    Returns:
        Multi-line string with a header and one row per parameter, or
        ``"No priors defined."`` for an empty dict.
    """
    if not priors:
        return "No priors defined."

    def _overflow_to_inf(fn: Any, x: float) -> float:
        # math.exp / math.expm1 raise OverflowError rather than returning inf.
        try:
            return fn(x)
        except OverflowError:
            return math.inf

    lines: list[str] = ["Prior summary:"]
    name_width = max(len(n) for n in priors) + 2
    for name, prior in priors.items():
        label = f" {name:<{name_width}}"
        if isinstance(prior, TwoSidedTruncatedDistribution):
            loc = float(prior.base_dist.loc)
            scale = float(prior.base_dist.scale)
            low = float(prior.low)
            high = float(prior.high)
            lines.append(
                f"{label}TruncatedNormal "
                f"loc={loc:.4e} scale={scale:.4e} "
                f"support=[{low:.4e}, {high:.4e}]"
            )
        elif isinstance(prior, dist.LogNormal):
            loc = float(prior.loc)
            scale = float(prior.scale)
            var_log = scale**2
            # median = exp(loc); mean = exp(loc + scale^2 / 2)
            median = _overflow_to_inf(math.exp, loc)
            mean = _overflow_to_inf(math.exp, loc + 0.5 * var_log)
            # std = sqrt(exp(scale^2) - 1) * mean; expm1 keeps precision
            # for small scale and the guard avoids OverflowError for large.
            std = math.sqrt(_overflow_to_inf(math.expm1, var_log)) * mean
            lines.append(
                f"{label}LogNormal "
                f"loc={loc:.4f} scale={scale:.4f} "
                f"median={median:.4e} mean={mean:.4e} std={std:.4e} "
                f"support=(0, +inf)"
            )
        elif isinstance(prior, dist.Uniform):
            low = float(prior.low)
            high = float(prior.high)
            mean = (low + high) / 2.0
            std = (high - low) / math.sqrt(12.0)
            lines.append(
                f"{label}Uniform "
                f"support=[{low:.4e}, {high:.4e}] "
                f"mean={mean:.4e} std={std:.4e}"
            )
        elif isinstance(prior, dist.TransformedDistribution):
            base = prior.base_dist
            if isinstance(base, dist.Beta):
                conc1 = float(base.concentration1)
                conc2 = float(base.concentration0)
                transforms = prior.transforms
                if transforms:
                    # Assumes the first transform is affine (BetaScaled).
                    t = transforms[0]
                    loc = float(t.loc)
                    scale = float(t.scale)
                    mean = loc + scale * conc1 / (conc1 + conc2)
                    lines.append(
                        f"{label}BetaScaled "
                        f"alpha={conc1:.4f} beta={conc2:.4f} "
                        f"support=[{loc:.4e}, {loc + scale:.4e}] "
                        f"mean={mean:.4e}"
                    )
                else:
                    lines.append(
                        f"{label}TransformedBeta alpha={conc1:.4f} beta={conc2:.4f}"
                    )
            else:
                lines.append(f"{label}Transformed({type(base).__name__})")
        else:
            lines.append(f"{label}{type(prior).__name__}")
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Parameter name helpers
# ---------------------------------------------------------------------------
[docs]
def get_param_names_in_order(
    vary_flags: dict[str, bool] | None = None,
) -> list[str]:
    """List the names of varying parameters in registry order.

    Order follows iteration over :data:`DEFAULT_REGISTRY`, which is
    described as matching the canonical group ordering of
    ``parameter_names.py`` (reference → sample → velocity → fraction →
    angle → scaling).

    Args:
        vary_flags: Optional per-name overrides of whether a parameter
            varies. Names missing from the mapping (or all names, when
            ``None``) use the registry's ``vary_default``.

    Returns:
        Names whose effective ``vary`` flag is truthy, in registry order.
    """
    overrides = vary_flags or {}
    selected = [
        pname
        for pname in DEFAULT_REGISTRY
        if overrides.get(pname, DEFAULT_REGISTRY[pname].vary_default)
    ]
    logger.debug(
        "get_param_names_in_order: %d varying parameters selected.", len(selected)
    )
    return selected
# ---------------------------------------------------------------------------
# Initial value construction and validation
# ---------------------------------------------------------------------------
[docs]
def validate_initial_value_bounds(
    init_values: dict[str, float],
    param_specs: dict[str, Any] | None = None,
) -> dict[str, list[str]]:
    """Check that each initial value lies within the parameter's bounds.

    NaN values are reported explicitly. A NaN compares ``False`` against
    both bounds and would otherwise pass as "in bounds".

    Args:
        init_values: Mapping of parameter name to proposed initial value.
        param_specs: Optional ``{name: {"min_bound": ..., "max_bound": ...}}``
            overrides. A missing key means an unbounded side. Names not in
            *param_specs* use bounds from :data:`DEFAULT_REGISTRY`. Names in
            neither are skipped with a warning.

    Returns:
        Mapping from parameter name to a list of problem strings. Only
        names with problems appear; an empty dict means every checked
        value is within bounds.
    """
    issues: dict[str, list[str]] = {}
    for name, value in init_values.items():
        # Determine bounds
        if param_specs and name in param_specs:
            spec = param_specs[name]
            low = float(spec.get("min_bound", -math.inf))
            high = float(spec.get("max_bound", math.inf))
        elif name in DEFAULT_REGISTRY:
            info = DEFAULT_REGISTRY[name]
            low = info.min_bound
            high = info.max_bound
        else:
            # Unknown parameter — skip bounds check but warn
            logger.warning(
                "validate_initial_value_bounds: unknown parameter '%s', "
                "not in registry or param_specs — skipping bounds check.",
                name,
            )
            continue

        param_issues: list[str] = []
        if math.isnan(value):
            param_issues.append(
                f"Value {value} is NaN; cannot lie within "
                f"[{low:.4e}, {high:.4e}]."
            )
        else:
            if value < low:
                param_issues.append(f"Value {value:.4e} is below min_bound {low:.4e}.")
            if value > high:
                param_issues.append(f"Value {value:.4e} is above max_bound {high:.4e}.")
        if param_issues:
            issues[name] = param_issues
            logger.warning(
                "validate_initial_value_bounds: %s — %s",
                name,
                "; ".join(param_issues),
            )

    if not issues:
        logger.debug(
            "validate_initial_value_bounds: all %d values within bounds.",
            len(init_values),
        )
    return issues
[docs]
def build_init_values_dict(
    nlsq_values: dict[str, float] | None = None,
    vary_flags: dict[str, bool] | None = None,
    fallback: str = "prior_mean",
) -> dict[str, float]:
    """Build an initial-values dict for NUTS warm-starting.

    For each varying parameter (see :func:`get_param_names_in_order`), the
    value is resolved in this order:

    1. The NLSQ estimate from ``nlsq_values``, if present **and finite**.
       A NaN/±inf estimate is logged and ignored. Previously a NaN was
       clamped onto ``max_bound``, i.e. straight onto the wall.
    2. Registry ``prior_mean`` when ``fallback="prior_mean"`` and it is
       not ``None``.
    3. Registry ``default``.

    Every resolved value is then clamped to ``[min_bound, max_bound]``,
    with a warning logged for each parameter that moved.

    Args:
        nlsq_values: Optional NLSQ MAP estimates keyed by parameter name.
        vary_flags: Optional dict controlling which parameters vary (see
            :func:`get_param_names_in_order`).
        fallback: Strategy for parameters without a usable NLSQ estimate.
            ``"prior_mean"`` uses the registry prior mean (default);
            ``"default"`` uses the registry default value.

    Returns:
        Dict mapping each varying parameter name to its initial value,
        ready to pass to :meth:`~heterodyne.optimization.cmc.sampler.NUTSSampler.run_with_init_values`.

    Raises:
        ValueError: If *fallback* is not ``"prior_mean"`` or ``"default"``.
    """
    if fallback not in {"prior_mean", "default"}:
        raise ValueError(
            f"fallback must be 'prior_mean' or 'default', got {fallback!r}"
        )
    nlsq_values = nlsq_values or {}
    param_names = get_param_names_in_order(vary_flags)
    init_values: dict[str, float] = {}
    n_from_nlsq = 0
    for name in param_names:
        info = DEFAULT_REGISTRY[name]
        value: float | None = None

        # 1. NLSQ estimate — only trusted when finite.
        if name in nlsq_values:
            candidate = float(nlsq_values[name])
            if math.isfinite(candidate):
                value = candidate
                n_from_nlsq += 1
            else:
                logger.warning(
                    "build_init_values_dict: %s NLSQ estimate %r is not finite; "
                    "using fallback=%r instead.",
                    name,
                    candidate,
                    fallback,
                )

        # 2/3. Fallback
        if value is None:
            if fallback == "prior_mean" and info.prior_mean is not None:
                value = float(info.prior_mean)
            else:
                value = float(info.default)

        # Clamp to bounds and warn if adjusted
        clamped = float(max(info.min_bound, min(info.max_bound, value)))
        if clamped != value:
            logger.warning(
                "build_init_values_dict: %s initial value %.4e clamped to [%.4e, %.4e] -> %.4e",
                name,
                value,
                info.min_bound,
                info.max_bound,
                clamped,
            )
        init_values[name] = clamped

    logger.info(
        "build_init_values_dict: built %d initial values (nlsq=%d, fallback=%r).",
        len(init_values),
        n_from_nlsq,
        fallback,
    )
    return init_values
# ---------------------------------------------------------------------------
# NLSQ value extraction for CMC warm-starting
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Per-angle scaling estimation
# ---------------------------------------------------------------------------
[docs]
def estimate_per_angle_scaling(
    data_dict: dict[str, Any],
    angle_keys: list[str] | None = None,
) -> dict[str, tuple[float, float]]:
    """Estimate contrast and offset scaling per scattering angle.

    Cheap heuristics on raw g2 data, meant as warm-start values for the
    ``contrast`` and ``offset`` scaling parameters (NLSQ/MCMC refine them):

    * ``contrast`` ≈ ``max(g2) - min(g2)`` over the finite g2 values.
    * ``offset`` ≈ mean of the last ``ceil(n / 10)`` finite g2 values (the
      long-lag baseline), clipped to ``[0, 1]``.

    Non-finite samples (NaN/±inf) are dropped first so they cannot turn an
    estimate into NaN/inf.

    Args:
        data_dict: Dict mapping angle keys to g2 data. Each value may be a
            1-D array-like ``(n_lags,)`` of g2 values, or a dict with a
            ``"g2"`` key holding such an array.
        angle_keys: Subset of keys in ``data_dict`` to process. Defaults to
            all keys when ``None``.

    Returns:
        Mapping of ``angle_key -> (contrast_estimate, offset_estimate)``.
        Keys are omitted when they are missing from *data_dict* (warning),
        have no ``"g2"`` entry, cannot be converted to a float array, or
        contain no finite values (debug log).
    """
    if angle_keys is None:
        angle_keys = list(data_dict.keys())

    result: dict[str, tuple[float, float]] = {}
    for key in angle_keys:
        raw = data_dict.get(key)
        if raw is None:
            logger.warning(
                "estimate_per_angle_scaling: key '%s' not found in data_dict.",
                key,
            )
            continue

        # Accept either a plain array or a dict with a "g2" sub-key
        if isinstance(raw, dict):
            g2_raw = raw.get("g2")
            if g2_raw is None:
                logger.debug(
                    "estimate_per_angle_scaling: key '%s' dict has no 'g2' entry.",
                    key,
                )
                continue
        else:
            g2_raw = raw

        try:
            g2 = np.asarray(g2_raw, dtype=float).ravel()
        except (ValueError, TypeError) as exc:
            logger.debug(
                "estimate_per_angle_scaling: cannot convert key '%s' to array: %s",
                key,
                exc,
            )
            continue

        # Keep finite samples only. np.nanmax/np.nanmean still return NaN for
        # an all-NaN array (or all-NaN tail), and inf would poison contrast.
        g2 = g2[np.isfinite(g2)]
        if g2.size == 0:
            logger.debug(
                "estimate_per_angle_scaling: key '%s' has no finite g2 values.",
                key,
            )
            continue

        contrast_est = float(g2.max()) - float(g2.min())

        # Baseline: mean of the last 10 % of points (long-tau asymptote).
        # Integer ceil(n / 10): float 0.1 * 30 == 3.0000000000000004 would
        # make np.ceil over-count the tail.
        n_tail = max(1, -(-g2.size // 10))
        offset_est = float(g2[-n_tail:].mean())
        # Physical constraint: offset in [0, 1]
        offset_est = float(np.clip(offset_est, 0.0, 1.0))

        result[key] = (contrast_est, offset_est)
        logger.debug(
            "estimate_per_angle_scaling: key='%s', contrast=%.4e, offset=%.4e",
            key,
            contrast_est,
            offset_est,
        )

    logger.info(
        "estimate_per_angle_scaling: estimated scaling for %d / %d angles.",
        len(result),
        len(angle_keys),
    )
    return result
[docs]
def estimate_contrast_offset_from_data(
    c2_data: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    contrast_bounds: tuple[float, float] = (0.0, 1.0),
    offset_bounds: tuple[float, float] = (0.5, 1.5),
    lag_floor_quantile: float = 0.80,
    lag_ceiling_quantile: float = 0.20,
    value_quantile_low: float = 0.10,
    value_quantile_high: float = 0.90,
) -> tuple[float, float]:
    """Estimate contrast and offset from C2 data via physics-informed quantile analysis.

    Uses the correlation decay ``C2 = contrast × g1² + offset``. At large
    lags ``g1² → 0`` so ``C2 → offset``; at small lags ``g1² ≈ 1`` so
    ``C2 ≈ contrast + offset``.

    The lag is ``|t1 - t2|`` (NumPy broadcasting applies) and is paired
    point-for-point with the flattened *c2_data*. When the two line up,
    points where either C2 or the lag is non-finite are dropped first.
    ``np.percentile`` would otherwise propagate NaN through ``np.clip``.

    Args:
        c2_data: C2 values.
        t1: First time coordinate, broadcastable against *t2*.
        t2: Second time coordinate.
        contrast_bounds: ``(low, high)`` clip range for the contrast.
        offset_bounds: ``(low, high)`` clip range for the offset.
        lag_floor_quantile: Lag quantile at or above which points count as
            "large lag" (offset region).
        lag_ceiling_quantile: Lag quantile at or below which points count as
            "small lag" (contrast region).
        value_quantile_low: C2 quantile taken as the offset.
        value_quantile_high: C2 quantile taken as the small-lag ceiling.

    Returns
    -------
    tuple[float, float]
        ``(contrast_est, offset_est)`` each clipped to their bounds. With
        fewer than 100 usable points, the midpoints of the two bound ranges.
    """
    delta_t = np.abs(np.asarray(t1) - np.asarray(t2)).ravel()
    c2 = np.asarray(c2_data).ravel()

    # Drop non-finite pairs. Only when the arrays align point-for-point;
    # a size mismatch is left to fail exactly as before.
    if delta_t.shape == c2.shape:
        keep = np.isfinite(c2) & np.isfinite(delta_t)
        if not keep.all():
            c2 = c2[keep]
            delta_t = delta_t[keep]

    if len(c2) < 100:
        contrast_mid = (contrast_bounds[0] + contrast_bounds[1]) / 2.0
        offset_mid = (offset_bounds[0] + offset_bounds[1]) / 2.0
        logger.debug(
            "estimate_contrast_offset_from_data: insufficient data (%d pts), "
            "returning midpoints contrast=%.3f offset=%.3f",
            len(c2),
            contrast_mid,
            offset_mid,
        )
        return contrast_mid, offset_mid

    lag_hi = np.percentile(delta_t, lag_floor_quantile * 100)
    lag_lo = np.percentile(delta_t, lag_ceiling_quantile * 100)

    # OFFSET: large-lag region where g1² ≈ 0
    mask_hi = delta_t >= lag_hi
    if np.sum(mask_hi) >= 10:
        offset_est = np.percentile(c2[mask_hi], value_quantile_low * 100)
    else:
        offset_est = np.percentile(c2, value_quantile_low * 100)
    offset_est = float(np.clip(offset_est, offset_bounds[0], offset_bounds[1]))

    # CONTRAST: small-lag region where g1² ≈ 1
    mask_lo = delta_t <= lag_lo
    if np.sum(mask_lo) >= 10:
        ceiling = np.percentile(c2[mask_lo], value_quantile_high * 100)
    else:
        ceiling = np.percentile(c2, value_quantile_high * 100)
    contrast_est = float(
        np.clip(ceiling - offset_est, contrast_bounds[0], contrast_bounds[1])
    )

    logger.debug(
        "estimate_contrast_offset_from_data: offset=%.4f contrast=%.4f",
        offset_est,
        contrast_est,
    )
    return contrast_est, offset_est
[docs]
def validate_init_values_order(
    init_values: dict[str, float],
    expected_names: list[str],
) -> None:
    """Raise unless the keys of *init_values* equal *expected_names*, in order.

    Homodyne CMC parity helper. Dicts preserve insertion order, so code
    that consumes ``init_values`` positionally (e.g. zipping it against
    NLSQ arrays) silently binds values to the wrong parameters if the
    order drifts. This guard turns such drift into a loud, descriptive
    failure.

    Args:
        init_values: Initial-values mapping, typically the output of
            :func:`build_init_values_dict`.
        expected_names: Required parameter order — usually
            :func:`get_param_names_in_order` for the active mode.

    Raises:
        ValueError: If the key count differs from ``len(expected_names)``
            (the message lists both name lists), or if the counts match
            but a position differs (the message names the first differing
            index and quotes both full lists).
    """
    observed = list(init_values)
    if observed == expected_names:
        return

    if len(observed) != len(expected_names):
        raise ValueError(
            "Parameter count mismatch in init_values:\n"
            f" expected {len(expected_names)} params: {expected_names}\n"
            f" actual {len(observed)} params: {observed}"
        )

    # Same length but not equal: locate the first position that differs.
    first_diff = next(
        (
            pos
            for pos, (got, want) in enumerate(zip(observed, expected_names))
            if got != want
        ),
        None,
    )
    if first_diff is None:
        return
    raise ValueError(
        "Parameter order mismatch at position "
        f"{first_diff}:\n expected: {expected_names[first_diff]!r}\n"
        f" actual: {observed[first_diff]!r}\n"
        f" full expected: {expected_names}\n"
        f" full actual: {observed}"
    )
# ---------------------------------------------------------------------------
# Absorbed from optimization/cmc/warmstart.py (Phase 4 PR 4 Task 4.4; spec §7 Rule 1)
# Preserves the two public call signatures + geometric-margin log-space math
# + fixed_param_overrides arg logic from REPORT.md §Manual narrative diffs §6.
# ---------------------------------------------------------------------------
#: Fraction of a parameter's bound range kept clear of each hard wall when
#: warm-start values are pulled inward (5%). It is applied linearly
#: (``margin * (max_bound - min_bound)``) for ordinary parameters and
#: geometrically (``margin * log(max_bound / min_bound)`` in log space) for
#: ``log_space=True`` parameters. Starting NUTS exactly on a TruncatedNormal
#: boundary collapses leapfrog step-size adaptation; this margin keeps the
#: warm-start away from the reflecting wall.
BOUNDARY_INTERIOR_MARGIN: float = 0.05
[docs]
def clamp_params_to_interior(
    params: np.ndarray,
    parameter_names: list[str],
    *,
    margin: float = BOUNDARY_INTERIOR_MARGIN,
) -> tuple[np.ndarray, list[str]]:
    """Shift parameter values inward from hard bounds (raw-array path).

    Lower-level companion to :func:`clamp_to_interior` (the NLSQResult
    entry point). The two public APIs use different names so they don't
    share a symbol. Callers without an ``NLSQResult`` (e.g. direct Python
    users passing raw arrays to :func:`fit_cmc_jax`) use this one.

    * **Linear-scale parameters** are clamped to
      ``[lo + margin * range, hi - margin * range]`` with
      ``range = hi - lo``.
    * **Log-space parameters** (registry ``log_space=True`` with a positive
      ``min_bound``, e.g. D0_ref, D0_sample, v0) use a *geometric* margin
      instead: ``[lo * f, hi / f]`` with ``f = (hi / lo) ** margin``. A
      linear fraction of a multi-decade range would push the lower clamp
      target orders of magnitude above ``min_bound``.

    Parameters are left untouched when any of these hold:

    * the name is not in the registry;
    * either bound is not finite — a fractional margin of an infinite range
      is undefined, and ``inf - inf`` would otherwise produce a NaN clamp
      target that overwrites the value;
    * the margins cross (``lo >= hi``).

    Args:
        params: Array of parameter values, shape ``(n,)``.
        parameter_names: Names corresponding to each entry in ``params``.
        margin: Fraction of bound range to keep clear of each wall.
            Default :data:`BOUNDARY_INTERIOR_MARGIN`.

    Returns:
        ``(new_params, clamped_names)`` — a fresh float array with values
        shifted inward, and the names whose values actually changed. If the
        registry cannot be imported, an unmodified copy and ``[]``.
    """
    try:
        from heterodyne.config.parameter_registry import ParameterRegistry

        registry = ParameterRegistry()
    except ImportError:
        return np.asarray(params).copy(), []

    out = np.asarray(params, dtype=float).copy()
    clamped: list[str] = []
    for i, name in enumerate(parameter_names):
        try:
            info = registry[name]
        except KeyError:
            continue  # unknown names pass through unchanged

        lo_bound, hi_bound = info.min_bound, info.max_bound
        if not (np.isfinite(lo_bound) and np.isfinite(hi_bound)):
            # Margin of an infinite range is undefined (would yield NaN).
            continue

        if info.log_space and lo_bound > 0:
            log_range = np.log(hi_bound / lo_bound)
            factor = np.exp(margin * log_range)
            lo = lo_bound * factor
            hi = hi_bound / factor
        else:
            span = hi_bound - lo_bound
            lo = lo_bound + margin * span
            hi = hi_bound - margin * span

        if lo >= hi:
            # Degenerate bounds — leave the value alone rather than
            # collapse it to the midpoint.
            continue

        old = float(out[i])
        new = float(np.clip(old, lo, hi))
        if new != old:
            out[i] = new
            clamped.append(name)
    return out, clamped
[docs]
def _safe_interior_window(info: Any, margin: float) -> tuple[float, float, float]:
    """Return ``(lo, hi, centre)`` of the NUTS-safe init window for *info*.

    *info* is a registry entry exposing ``log_space``, ``min_bound`` and
    ``max_bound``. Log-space entries with a positive lower bound use a
    geometric margin: each wall is pulled inward by the same
    multiplicative factor ``exp(margin * log(max/min))``. Their centre is
    the geometric mean of the window. All other entries use a linear
    margin of ``margin * (max_bound - min_bound)`` and the arithmetic
    centre. When ``lo >= hi`` the window is degenerate and callers should
    leave the value alone.
    """
    lower, upper = info.min_bound, info.max_bound
    if info.log_space and lower > 0:
        factor = np.exp(margin * np.log(upper / lower))
        lo, hi = lower * factor, upper / factor
        centre = np.sqrt(lo * hi) if lo < hi else lo
        return float(lo), float(hi), float(centre)
    span = upper - lower
    lo, hi = lower + margin * span, upper - margin * span
    return float(lo), float(hi), float(0.5 * (lo + hi))


def clamp_to_interior(
    result: NLSQResult,
    fixed_param_overrides: dict[str, float] | None = None,
    *,
    margin: float = BOUNDARY_INTERIOR_MARGIN,
) -> NLSQResult:
    """Return a copy of *result* with parameters shifted inward from hard bounds.

    NUTS step-size collapses when the chain initialises at a
    TruncatedNormal boundary.

    - Linear-scale parameters are clamped to
      ``[min_bound + margin * range, max_bound - margin * range]``, where
      ``range = max_bound - min_bound``.
    - Log-space parameters (registry ``log_space=True`` with a positive
      ``min_bound``, e.g. D0, v0) use a geometric margin. A linear fraction
      of a multi-decade range would otherwise produce an absurd clamp
      target.

    Both rules keep the leapfrog step-size adaptation away from the
    reflecting wall.

    A NaN parameter value cannot be clamped, because ``np.clip`` propagates
    NaN. It is replaced by the centre of the safe window (geometric centre
    for log-space parameters), and a warning is logged.

    ``fixed_param_overrides`` maps parameter names to values from the
    current model config's ``fixed_parameters``. When provided, any
    NLSQ result value for a fixed parameter is replaced with the config
    value, and that parameter is not bounds-clamped afterwards. This
    prevents a stale ``nlsq_data.npz`` (fitted in a prior run where the
    parameter was free) from propagating a superseded value into CMC
    initialisation. Such a value can place the warm-start outside the
    reparameterised prior support and cause log-prior = −∞ → BFMI = 0
    across all shards.

    Args:
        result: The NLSQ fit result to clamp.
        fixed_param_overrides: Optional map from parameter name to
            config-fixed value applied before clamping.
        margin: Fraction of bound range to keep clear of each wall.
            Default :data:`BOUNDARY_INTERIOR_MARGIN`.

    Returns:
        A new :class:`NLSQResult` (with a float64 NumPy ``parameters``
        array) if any value was overridden or moved. Otherwise the
        original ``result`` object itself.
    """
    try:
        from heterodyne.config.parameter_registry import ParameterRegistry

        registry = ParameterRegistry()
    except ImportError:
        return result

    # Work on a float64 NumPy copy. A plain ``.copy()`` would keep an
    # integer dtype (silently truncating clamped values) or an immutable
    # array type that rejects item assignment.
    params = np.array(result.parameters, dtype=float)
    clamped: list[str] = []
    for i, name in enumerate(result.parameter_names):
        old = float(params[i])

        # Apply fixed-parameter overrides before bounds clamping.
        if fixed_param_overrides and name in fixed_param_overrides:
            new = float(fixed_param_overrides[name])
            # ``abs(new - nan) > tol`` is False, so NaN must be tested
            # explicitly or a NaN NLSQ value would survive the override.
            if math.isnan(old) or abs(new - old) > 1e-12:
                logger.info(
                    "CMC init override: %s %.4g → %.4g "
                    "(fixed in current config; stale NLSQ value replaced)",
                    name,
                    old,
                    new,
                )
                params[i] = new
                clamped.append(name)
            # Fixed parameters are not sampled, so the config value is kept
            # as-is rather than being bounds-clamped.
            continue

        try:
            info = registry[name]
        except KeyError:
            continue

        lo, hi, centre = _safe_interior_window(info, margin)
        if lo >= hi:
            # Degenerate window: leave the value alone.
            continue

        if math.isnan(old):
            logger.warning(
                "CMC init clamp: %s is NaN in the NLSQ result; initialising "
                "at interior centre %.4g (bounds [%.4g, %.4g])",
                name,
                centre,
                info.min_bound,
                info.max_bound,
            )
            params[i] = centre
            clamped.append(name)
            continue

        new = float(np.clip(old, lo, hi))
        if new == old:
            continue

        _at_lb = abs(old - info.min_bound) < 1e-10 * max(abs(info.min_bound), 1)
        _at_ub = abs(old - info.max_bound) < 1e-10 * max(abs(info.max_bound), 1)
        if _at_lb or _at_ub:
            _side = "lower" if _at_lb else "upper"
            logger.warning(
                "CMC init clamp: %s %.4g → %.4g (bounds [%.4g, %.4g]); "
                "NLSQ hit the %s bound exactly — true mode may be outside "
                "current bounds; expect high NUTS divergence rate. "
                "Consider widening the %s bound in your config.",
                name,
                old,
                new,
                info.min_bound,
                info.max_bound,
                _side,
                _side,
            )
        else:
            logger.warning(
                "CMC init clamp: %s %.4g → %.4g (bounds [%.4g, %.4g]); "
                "NUTS boundary-adjacent start prevented",
                name,
                old,
                new,
                info.min_bound,
                info.max_bound,
            )
        clamped.append(name)
        params[i] = new

    if not clamped:
        return result
    return dataclasses.replace(result, parameters=params)
# ---------------------------------------------------------------------------
# Absorbed from optimization/cmc/prior_builder.py (Phase 4 PR 4 Task 4.5;
# spec §7 Rule 1). Preserves PriorBuilder construction-time validation
# semantics (CLAUDE.md Rule 9 dual-prior sync gate) from T1-T4 stability work.
# ---------------------------------------------------------------------------
_DEFAULT_SYNC_REL_TOL: float = 1e-6
_DEFAULT_SYNC_ABS_TOL: float = 1e-9
[docs]
class PriorBuilder:
    """Factory for NumPyro prior dictionaries backed by a parameter registry.

    Args:
        registry: Registry supplying bounds and prior moments. ``None``
            selects
            :data:`heterodyne.config.parameter_registry.DEFAULT_REGISTRY`.
        use_log_space_priors: If true (the default), parameters whose
            registry entry has ``log_space=True`` get LogNormal priors
            instead of TruncatedNormal ones. This gives better mass-matrix
            conditioning for multi-decade prefactors (D0_ref, D0_sample,
            v0). See codex S1.

    Raises:
        RuntimeError: If the registry's ``(prior_mean, prior_std)`` and
            ``parameter_space._DEFAULT_PRIOR_SPECS``'s ``(loc, scale)``
            disagree. This is Rule 9 of ``CLAUDE.md``: the dual-prior
            system must stay in sync. Tolerance is ``rel_tol=1e-6,
            abs_tol=1e-9``.
    """

    def __init__(
        self,
        registry: ParameterRegistry | None = None,
        use_log_space_priors: bool = True,
    ) -> None:
        if registry is None:
            registry = DEFAULT_REGISTRY
        self._registry = registry
        self._use_log = bool(use_log_space_priors)
        # Fail fast: a desynced registry/spec pair is rejected at
        # construction, before any prior is built.
        self._verify_registry_spec_sync()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def build(self, param_space: ParameterSpace) -> dict[str, dist.Distribution]:
        """Create priors for every varying parameter of *param_space*.

        Args:
            param_space: ParameterSpace that defines which parameters
                vary and their physical bounds.

        Returns:
            Mapping from each varying parameter name to its NumPyro
            distribution, produced by :func:`build_default_priors` with
            this builder's registry and log-space setting. D0_ref,
            D0_sample and v0 are LogNormal when
            ``use_log_space_priors=True``, TruncatedNormal-family
            otherwise. Remaining parameters are always TruncatedNormal.
        """
        priors = build_default_priors(
            param_space,
            use_log_space_priors=self._use_log,
            registry=self._registry,
        )
        return priors

    # ------------------------------------------------------------------
    # Construction-time invariant (CLAUDE.md Rule 9)
    # ------------------------------------------------------------------

    def _verify_registry_spec_sync(self) -> None:
        """Compare this builder's registry against ``_DEFAULT_PRIOR_SPECS``.

        All comparison logic lives in the module-level
        :func:`_check_registry_spec_sync`. The function-style
        :func:`build_default_priors` path and this builder path therefore
        share one implementation; see that function for the two sources
        of truth and the tolerance rationale.
        """
        registry = self._registry
        _check_registry_spec_sync(registry)
[docs]
def build_default_priors_via_builder(
    param_space: ParameterSpace,
    registry: ParameterRegistry | None = None,
    use_log_space_priors: bool = True,
) -> dict[str, dist.Distribution]:
    """One-shot functional wrapper over :class:`PriorBuilder`.

    The builder validates the dual-prior invariant when it is
    constructed, so every call re-runs the sync gate before building the
    priors for *param_space*.
    """
    builder = PriorBuilder(
        registry=registry,
        use_log_space_priors=use_log_space_priors,
    )
    return builder.build(param_space)
[docs]
def build_log_space_priors_via_builder(
    param_names: list[str],
    registry: ParameterRegistry | None = None,
) -> dict[str, dist.Distribution]:
    """Sync-gated wrapper around the legacy log-space-only helper.

    A :class:`PriorBuilder` is created (and then discarded) only for its
    construction-time sync check. A desynced registry/spec pair therefore
    fails loudly before :func:`build_log_space_priors` runs.
    """
    # Instance intentionally unused: construction alone runs the gate.
    _gate = PriorBuilder(registry=registry, use_log_space_priors=True)
    del _gate
    return build_log_space_priors(param_names, registry=registry)