# Source code for heterodyne.config.parameter_space

"""Parameter space definition with prior distributions for Bayesian inference."""

from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any

import numpy as np

from heterodyne.config.parameter_names import (
    ALL_PARAM_NAMES,
    ALL_PARAM_NAMES_WITH_SCALING,
    SCALING_PARAMS,
)
from heterodyne.config.parameter_registry import DEFAULT_REGISTRY
from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    import jax.numpy as jnp

# Logger scoped to this module's dotted name.
logger = get_logger(__name__)

# Parameters for which a *fixed* value of exactly 0.0 is an intentional
# "switch this term off" marker rather than a configuration mistake.
# For these names the registry min_bound is only an optimizer stability floor,
# so it should not reject a fixed 0.0 during validation.
# Example: v0 = 0.0 disables the velocity term, while its min_bound (1e-6)
# exists to keep log-space NLSQ numerically stable when v0 is varied.
_FIXED_ZERO_SENTINELS: frozenset[str] = frozenset(["v0"])


class PriorType(Enum):
    """Closed set of prior distribution families understood by this module.

    The string value of each member is the spelling accepted in configuration
    files; ``PriorType("<value>")`` maps that string back to the member.
    """

    # Flat density between ``low`` and ``high``.
    UNIFORM = "uniform"
    # Gaussian with ``loc`` / ``scale``.
    NORMAL = "normal"
    # Gaussian restricted to ``[low, high]``.
    TRUNCATED_NORMAL = "truncated_normal"
    # Log-normal with ``loc`` / ``scale`` (for positive parameters).
    LOGNORMAL = "lognormal"
    # Half-normal with ``scale`` (for positive parameters).
    HALFNORMAL = "halfnormal"
    # Exponential, parameterized by ``rate``.
    EXPONENTIAL = "exponential"
    # Beta(concentration1, concentration2) rescaled onto ``[low, high]``.
    BETA_SCALED = "beta_scaled"
@dataclass
class PriorDistribution:
    """Prior distribution specification for a parameter.

    Attributes:
        prior_type: Distribution family.
        params: Family-specific parameters. The keys read per family are the
            ones used by :meth:`to_numpyro` (e.g. ``low``/``high`` for
            uniform, ``loc``/``scale`` for normal, ``rate`` for exponential).
    """

    prior_type: PriorType
    params: dict[str, float] = field(default_factory=dict)

    @classmethod
    def uniform(cls, low: float, high: float) -> PriorDistribution:
        """Create uniform prior on ``[low, high]``."""
        return cls(PriorType.UNIFORM, {"low": low, "high": high})

    @classmethod
    def normal(cls, loc: float, scale: float) -> PriorDistribution:
        """Create normal (Gaussian) prior."""
        return cls(PriorType.NORMAL, {"loc": loc, "scale": scale})

    @classmethod
    def lognormal(cls, loc: float, scale: float) -> PriorDistribution:
        """Create log-normal prior (for positive parameters)."""
        return cls(PriorType.LOGNORMAL, {"loc": loc, "scale": scale})

    @classmethod
    def halfnormal(cls, scale: float) -> PriorDistribution:
        """Create half-normal prior (for positive parameters)."""
        return cls(PriorType.HALFNORMAL, {"scale": scale})

    @classmethod
    def exponential(cls, rate: float = 1.0) -> PriorDistribution:
        """Create exponential prior with the given rate.

        The default ``rate=1.0`` matches the fallback :meth:`to_numpyro` uses
        when an EXPONENTIAL prior carries no ``rate`` key, so
        ``exponential()`` and ``PriorDistribution(PriorType.EXPONENTIAL, {})``
        produce the same NumPyro distribution.

        Args:
            rate: Rate parameter of the exponential distribution.

        Returns:
            PriorDistribution with EXPONENTIAL type.
        """
        return cls(PriorType.EXPONENTIAL, {"rate": rate})

    @classmethod
    def truncated_normal(
        cls,
        loc: float,
        scale: float,
        low: float,
        high: float,
    ) -> PriorDistribution:
        """Create truncated normal prior (Gaussian restricted to ``[low, high]``)."""
        return cls(
            PriorType.TRUNCATED_NORMAL,
            {"loc": loc, "scale": scale, "low": low, "high": high},
        )

    @classmethod
    def beta_scaled(
        cls,
        low: float,
        high: float,
        concentration1: float,
        concentration2: float,
    ) -> PriorDistribution:
        """Create a Beta prior scaled to [low, high].

        The distribution is Beta(concentration1, concentration2) affine-transformed
        to the interval [low, high]. This is useful for bounded parameters where
        you want to express a prior belief about the shape within the bounds.

        Args:
            low: Lower bound of the support.
            high: Upper bound of the support.
            concentration1: First concentration parameter (alpha > 0).
            concentration2: Second concentration parameter (beta > 0).

        Returns:
            PriorDistribution with BETA_SCALED type.
        """
        return cls(
            PriorType.BETA_SCALED,
            {
                "low": low,
                "high": high,
                "concentration1": concentration1,
                "concentration2": concentration2,
            },
        )

    def to_numpyro(self, name: str) -> Any:
        """Convert to NumPyro distribution.

        NumPyro is imported lazily so this module can be used without it.

        Args:
            name: Parameter name. Not used by the current implementation.

        Returns:
            NumPyro distribution object.

        Raises:
            KeyError: If a required key for the family is missing from
                ``params`` (EXPONENTIAL falls back to ``rate=1.0``).
            ValueError: If ``prior_type`` is not a handled family.
        """
        import numpyro.distributions as dist

        p = self.params
        if self.prior_type == PriorType.UNIFORM:
            return dist.Uniform(p["low"], p["high"])
        elif self.prior_type == PriorType.NORMAL:
            return dist.Normal(p["loc"], p["scale"])
        elif self.prior_type == PriorType.TRUNCATED_NORMAL:
            return dist.TruncatedNormal(
                loc=p["loc"],
                scale=p["scale"],
                low=p["low"],
                high=p["high"],
            )
        elif self.prior_type == PriorType.LOGNORMAL:
            return dist.LogNormal(p["loc"], p["scale"])
        elif self.prior_type == PriorType.HALFNORMAL:
            return dist.HalfNormal(p["scale"])
        elif self.prior_type == PriorType.EXPONENTIAL:
            return dist.Exponential(p.get("rate", 1.0))
        elif self.prior_type == PriorType.BETA_SCALED:
            low = p["low"]
            high = p["high"]
            # Affine-transformed Beta: X = low + (high - low) * Beta(conc1, conc2)
            base = dist.Beta(p["concentration1"], p["concentration2"])
            return dist.TransformedDistribution(
                base,
                dist.transforms.AffineTransform(loc=low, scale=high - low),
            )
        else:
            raise ValueError(f"Unknown prior type: {self.prior_type}")
@dataclass
class ParameterSpace:
    """Complete parameter space for heterodyne model optimization.

    Manages parameter values, bounds, vary flags, and priors for the physics
    parameters (``ALL_PARAM_NAMES``) and the scaling parameters
    (``SCALING_PARAMS``). Any entry not supplied at construction time is
    filled from ``DEFAULT_REGISTRY`` in ``__post_init__``.
    """

    # Current value of each parameter, keyed by name.
    values: dict[str, float] = field(default_factory=dict)
    # True if the parameter is optimized, False if it is held fixed.
    vary: dict[str, bool] = field(default_factory=dict)
    # (lower, upper) bounds of each parameter.
    bounds: dict[str, tuple[float, float]] = field(default_factory=dict)
    # Prior distribution of each parameter.
    priors: dict[str, PriorDistribution] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Initialize with defaults from registry.

        Only missing entries are filled; caller-supplied entries are kept.
        """
        for name in ALL_PARAM_NAMES_WITH_SCALING:
            info = DEFAULT_REGISTRY[name]
            if name not in self.values:
                self.values[name] = info.default
            if name not in self.vary:
                self.vary[name] = info.vary_default
            if name not in self.bounds:
                self.bounds[name] = (info.min_bound, info.max_bound)
            if name not in self.priors:
                self.priors[name] = _default_prior(name, info)

    @property
    def n_total(self) -> int:
        """Total number of parameters.

        Counts ``ALL_PARAM_NAMES`` only (scaling parameters excluded), while
        :attr:`n_varying` also counts varying scaling parameters.
        """
        return len(ALL_PARAM_NAMES)

    @property
    def n_varying(self) -> int:
        """Number of parameters that vary in optimization (physics + scaling)."""
        return len(self.varying_names)

    @property
    def varying_names(self) -> list[str]:
        """Names of parameters that vary (physics + scaling)."""
        return [
            name for name in ALL_PARAM_NAMES_WITH_SCALING if self.vary.get(name, False)
        ]

    @property
    def fixed_names(self) -> list[str]:
        """Names of parameters that are fixed (physics + scaling)."""
        return [
            name
            for name in ALL_PARAM_NAMES_WITH_SCALING
            if not self.vary.get(name, False)
        ]

    @property
    def varying_physics_names(self) -> list[str]:
        """Names of varying physics parameters (excludes scaling)."""
        return [name for name in ALL_PARAM_NAMES if self.vary.get(name, False)]

    @property
    def scaling_values(self) -> dict[str, float]:
        """Get contrast and offset values."""
        return {name: self.values[name] for name in SCALING_PARAMS}

    def get_initial_array(self) -> np.ndarray:
        """Get initial values as numpy array in canonical order.

        Returns:
            Array of shape (14,) with parameter values
        """
        return np.array([self.values[name] for name in ALL_PARAM_NAMES])

    def to_config(self) -> dict[str, Any]:
        """Serialize this space to a dict compatible with :meth:`from_config`.

        Produces the ``initial_parameters`` flat-format understood by
        :func:`_apply_initial_parameters`. Bounds and priors are not
        serialized — workers rebuild them from the registry defaults.
        Only values and ``active_parameters`` (vary flags) are round-tripped.

        Returns:
            Config dict that ``from_config()`` can reconstruct into an
            equivalent ParameterSpace (same values and varying_names).
        """
        return {
            "initial_parameters": {
                "parameter_names": list(ALL_PARAM_NAMES_WITH_SCALING),
                "values": [
                    float(self.values[name]) for name in ALL_PARAM_NAMES_WITH_SCALING
                ],
                "active_parameters": list(self.varying_names),
            }
        }

    def get_bounds_arrays(self) -> tuple[np.ndarray, np.ndarray]:
        """Get bounds as numpy arrays.

        Returns:
            (lower_bounds, upper_bounds) each of shape (14,)
        """
        lower = np.array([self.bounds[name][0] for name in ALL_PARAM_NAMES])
        upper = np.array([self.bounds[name][1] for name in ALL_PARAM_NAMES])
        return lower, upper

    def get_vary_mask(self) -> np.ndarray:
        """Get boolean mask for varying parameters.

        Returns:
            Boolean array of shape (14,)
        """
        return np.array([self.vary[name] for name in ALL_PARAM_NAMES])

    def array_to_dict(self, arr: np.ndarray | jnp.ndarray) -> dict[str, float]:
        """Convert parameter array to dictionary.

        Args:
            arr: Array of shape (14,) in ``ALL_PARAM_NAMES`` order.

        Returns:
            Dict mapping parameter names to values
        """
        return {name: float(arr[i]) for i, name in enumerate(ALL_PARAM_NAMES)}

    def update_from_dict(self, params: dict[str, float]) -> None:
        """Update parameter values from dictionary.

        Every key is checked before any value is written, so an unknown key
        leaves the space unchanged rather than partially updated.

        Args:
            params: Dict with parameter names as keys

        Raises:
            ValueError: If a key doesn't match any known parameter
        """
        for name in params:
            if name not in self.values:
                # List exactly the accepted names (includes scaling parameters).
                raise ValueError(
                    f"Unknown parameter '{name}'. "
                    f"Valid parameters: {list(self.values)}"
                )
        self.values.update(params)

    def validate(self) -> list[str]:
        """Validate parameter space configuration.

        Covers all 16 parameters (14 physics + 2 scaling) so a malformed
        ``contrast``/``offset`` bound is caught at config-load time instead
        of propagating into the optimizer.

        Bounds are checked for every parameter — varying and fixed — because
        fixed parameters still drive model outputs and CMC warm-starts.
        The one explicit exception: parameters in ``_FIXED_ZERO_SENTINELS``
        (currently ``v0``) are allowed to be 0.0 when fixed, because their
        min_bound is an optimizer stability floor that does not apply when
        the optimizer never touches the value.

        Returns:
            List of validation error messages (empty if valid)
        """
        errors = []
        for name in ALL_PARAM_NAMES_WITH_SCALING:
            value = self.values.get(name)
            bounds = self.bounds.get(name)
            if value is None:
                errors.append(f"Missing value for {name}")
                continue
            if bounds is None:
                errors.append(f"Missing bounds for {name}")
                continue
            low, high = bounds
            if low >= high:
                errors.append(f"{name} has inverted/degenerate bounds [{low}, {high}]")
                continue
            if not (low <= value <= high):
                # Allow documented zero-sentinels for fixed parameters: their
                # min_bound is an optimizer floor, not a physics constraint.
                is_fixed_zero_sentinel = (
                    name in _FIXED_ZERO_SENTINELS
                    and not self.vary.get(name, False)
                    and value == 0.0
                )
                if not is_fixed_zero_sentinel:
                    errors.append(f"{name}={value} outside bounds [{low}, {high}]")
        return errors

    def convert_to_beta_priors(self) -> None:
        """Convert all TruncatedNormal priors to BetaScaled priors.

        For each parameter whose prior is TRUNCATED_NORMAL, this method
        computes equivalent Beta concentration parameters via the method of
        moments (using the prior's loc/scale and the parameter's entry in
        ``self.bounds``) and replaces the prior in-place with a BETA_SCALED
        distribution over those bounds.

        Parameters with other prior types are left unchanged.

        Raises:
            ValueError: Propagated from ``_compute_beta_concentrations`` when
                loc/scale cannot be represented by a Beta on the bounds.
        """
        # Iterate over a snapshot because entries are replaced inside the loop.
        for name, prior in list(self.priors.items()):
            if prior.prior_type != PriorType.TRUNCATED_NORMAL:
                continue

            loc = prior.params["loc"]
            scale = prior.params["scale"]
            low, high = self.bounds[name]

            conc1, conc2 = _compute_beta_concentrations(loc, scale, low, high)

            self.priors[name] = PriorDistribution.beta_scaled(
                low,
                high,
                conc1,
                conc2,
            )
            logger.debug(
                "Converted %s prior: TruncatedNormal(loc=%.4g, scale=%.4g) "
                "-> BetaScaled(conc1=%.4g, conc2=%.4g) on [%.4g, %.4g]",
                name,
                loc,
                scale,
                conc1,
                conc2,
                low,
                high,
            )

    def with_single_angle_stabilization(self) -> ParameterSpace:
        """Return a new ParameterSpace with tightened bounds for single-angle analysis.

        Narrows contrast bounds to [value-0.2, value+0.2] and offset bounds
        to [value-0.1, value+0.1], clamped to the original bounds. Priors are
        deep-copied unchanged. If this space carries the ``_config_dict``
        attribute set by :meth:`from_config`, the copy carries it too.

        Returns:
            A new ParameterSpace with tightened scaling bounds.
        """
        new = ParameterSpace(
            values=deepcopy(self.values),
            vary=deepcopy(self.vary),
            bounds=deepcopy(self.bounds),
            priors=deepcopy(self.priors),
        )
        # from_config() attaches the source config; keep it on the copy so
        # code reading ``_config_dict`` from the stabilized space still works.
        config_dict = getattr(self, "_config_dict", None)
        if config_dict is not None:
            new._config_dict = config_dict  # type: ignore[attr-defined]

        # (parameter name, half-width of the stabilized window)
        for name, half_width in (("contrast", 0.2), ("offset", 0.1)):
            if name not in new.bounds:
                continue
            low, high = new.bounds[name]
            val = new.values[name]
            new_low = max(low, val - half_width)
            new_high = min(high, val + half_width)
            new.bounds[name] = (new_low, new_high)
            logger.debug(
                "Single-angle stabilization: %s bounds [%.4g, %.4g] -> [%.4g, %.4g]",
                name,
                low,
                high,
                new_low,
                new_high,
            )

        return new

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> ParameterSpace:
        """Create ParameterSpace from configuration dictionary.

        Supports two input formats (homodyne parity):

        1. **Grouped format** (preferred) — ``parameters.{group}.{param}``::

               parameters:
                 reference:
                   D0_ref:
                     value: 5000.0
                     min: 200.0
                     max: 50000.0
                     vary: true

        2. **Flat format** — ``initial_parameters.parameter_names`` + ``values``::

               initial_parameters:
                 parameter_names: [D0_ref, alpha_ref]
                 values: [5000.0, 0.5]
                 active_parameters: [D0_ref]   # optional vary subset

        When both are present, grouped format takes precedence (it is applied
        second so its values overwrite flat-format values). Sections left
        empty in YAML (``parameters:``, a group, or ``prior_params:`` parsed
        as ``None``) are treated as empty mappings.

        Args:
            config: Config dict with 'parameters' and/or 'initial_parameters' sections.

        Returns:
            Configured ParameterSpace

        Raises:
            ValueError: On an unknown key within a group, a half-specified or
                inverted min/max pair, an unknown prior type, or any error
                reported by :meth:`validate`.
        """
        space = cls()

        # --- Flat format: initial_parameters (homodyne parity) ---------------
        _apply_initial_parameters(space, config)

        # --- Grouped format: parameters.{group}.{param} (primary) -----------
        # ``or {}``: an empty YAML section parses as None, not {}.
        params_config = config.get("parameters") or {}

        group_map = {
            "reference": ["D0_ref", "alpha_ref", "D_offset_ref"],
            "sample": ["D0_sample", "alpha_sample", "D_offset_sample"],
            "velocity": ["v0", "beta", "v_offset"],
            "fraction": ["f0", "f1", "f2", "f3"],
            "angle": ["phi0"],
            "scaling": ["contrast", "offset"],
        }

        for group_name, param_names in group_map.items():
            group_config = params_config.get(group_name) or {}

            # Check for unknown keys in this group
            known_params = set(param_names)
            for ck in group_config:
                if ck not in known_params:
                    raise ValueError(
                        f"Unknown parameter key '{ck}' in group '{group_name}'. "
                        f"Valid keys: {param_names}"
                    )

            for param_name in param_names:
                # Direct key match only — no substring matching
                if param_name not in group_config:
                    continue
                pconfig = group_config[param_name]
                if not isinstance(pconfig, dict):
                    continue

                reg_info = DEFAULT_REGISTRY[param_name]

                if "value" in pconfig:
                    new_val = pconfig["value"]
                    if new_val != reg_info.default:
                        logger.debug(
                            "Config overrides %s value: %.6g -> %.6g",
                            param_name,
                            reg_info.default,
                            new_val,
                        )
                    space.values[param_name] = new_val

                has_min = "min" in pconfig
                has_max = "max" in pconfig
                if has_min ^ has_max:
                    raise ValueError(
                        f"Parameter '{param_name}' has only one of 'min'/'max' "
                        f"set; both bounds must be specified together."
                    )
                if has_min and has_max:
                    new_bounds = (pconfig["min"], pconfig["max"])
                    if new_bounds[0] >= new_bounds[1]:
                        raise ValueError(
                            f"Parameter '{param_name}' has inverted/degenerate "
                            f"bounds: min={new_bounds[0]} >= max={new_bounds[1]}"
                        )
                    if (
                        new_bounds[0] != reg_info.min_bound
                        or new_bounds[1] != reg_info.max_bound
                    ):
                        logger.debug(
                            "Config overrides %s bounds: [%.4g, %.4g] -> [%.4g, %.4g]",
                            param_name,
                            reg_info.min_bound,
                            reg_info.max_bound,
                            new_bounds[0],
                            new_bounds[1],
                        )
                    space.bounds[param_name] = new_bounds

                if "vary" in pconfig:
                    new_vary = pconfig["vary"]
                    if new_vary != reg_info.vary_default:
                        logger.debug(
                            "Config overrides %s vary: %s -> %s",
                            param_name,
                            reg_info.vary_default,
                            new_vary,
                        )
                    space.vary[param_name] = new_vary

                if "prior" in pconfig:
                    prior_type_str = pconfig["prior"]
                    prior_params = pconfig.get("prior_params") or {}
                    # Bounds were updated above, so the prior sees the final bounds.
                    space.priors[param_name] = _build_prior(
                        param_name,
                        prior_type_str,
                        prior_params,
                        space.bounds[param_name],
                    )

        space._config_dict: dict[str, Any] = config  # type: ignore[attr-defined]

        # Validate the assembled space — catches inverted/degenerate bounds and
        # values outside bounds at config-load time instead of surfacing as
        # cryptic optimizer failures later.
        validation_errors = space.validate()
        if validation_errors:
            raise ValueError(
                "Invalid parameter configuration: " + "; ".join(validation_errors)
            )

        return space
def _apply_initial_parameters(space: ParameterSpace, config: dict[str, Any]) -> None:
    """Apply ``initial_parameters`` flat-format values to *space*.

    Homodyne parity: supports::

        initial_parameters:
          parameter_names: [D0_ref, alpha_ref, ...]
          values: [5000.0, 0.5, ...]
          active_parameters: [D0_ref]   # optional: only these vary

    Names (and active names) pass through ``PARAMETER_NAME_MAPPING`` so
    legacy/alias names resolve to canonical ones. Unknown names are logged
    and skipped. If ``parameter_names``/``values`` are missing, not lists, or
    differ in length, nothing is applied (including ``active_parameters``).

    ``active_parameters`` is honoured whenever it is a list, including an
    empty list, which fixes every parameter. ``ParameterSpace.to_config``
    emits ``[]`` for a space with no varying parameters, so treating ``[]``
    as "not provided" would break that round trip.

    Args:
        space: ParameterSpace to modify in-place.
        config: Full configuration dictionary.
    """
    from heterodyne.config.types import PARAMETER_NAME_MAPPING

    initial = config.get("initial_parameters", {})
    if not initial or not isinstance(initial, dict):
        return

    param_names_raw = initial.get("parameter_names")
    param_values = initial.get("values")
    if (
        not param_names_raw
        or not isinstance(param_names_raw, list)
        or param_values is None
        or not isinstance(param_values, list)
    ):
        return

    # Apply name mapping for legacy/alias names
    param_names = [PARAMETER_NAME_MAPPING.get(str(n), str(n)) for n in param_names_raw]

    if len(param_names) != len(param_values):
        logger.warning(
            "initial_parameters: parameter_names (%d) and values (%d) length mismatch; "
            "skipping flat-format override",
            len(param_names),
            len(param_values),
        )
        return

    for name, raw_value in zip(param_names, param_values, strict=True):
        if name not in space.values:
            logger.warning("initial_parameters: unknown parameter '%s', skipping", name)
            continue
        value = float(raw_value)
        space.values[name] = value
        # Log the converted float: %.6g cannot format a raw string value.
        logger.debug(
            "initial_parameters: set %s = %.6g (flat-format override)", name, value
        )

    # active_parameters: if provided (even as []), only these parameters vary
    active_raw = initial.get("active_parameters")
    if isinstance(active_raw, list):
        active_names = {PARAMETER_NAME_MAPPING.get(str(n), str(n)) for n in active_raw}
        for name in ALL_PARAM_NAMES_WITH_SCALING:
            if name in active_names:
                space.vary[name] = True
            elif name in space.vary:
                space.vary[name] = False
        logger.debug(
            "initial_parameters: active_parameters set %d params to vary",
            len(active_names),
        )


# Default TruncatedNormal prior specifications: (loc, scale)
# All parameters use TruncatedNormal priors truncated to their registry bounds.
# IMPORTANT: must stay in sync with parameter_registry.py prior_mean/prior_std
# (CLAUDE.md rule #9 — dual prior system). See tests/unit/test_prior_sanity.py
# for the contract test that enforces this.
_DEFAULT_PRIOR_SPECS: dict[str, tuple[float, float]] = {
    "D0_ref": (1e4, 1e4),  # widened from 5e3 → 1e4 (see registry comment)
    "alpha_ref": (0.0, 1.0),
    "D_offset_ref": (0.0, 1e3),
    "D0_sample": (1e4, 1e4),  # widened from 5e3 → 1e4 (see registry comment)
    "alpha_sample": (0.0, 1.0),
    "D_offset_sample": (0.0, 1e3),
    "v0": (1e3, 1000.0),  # widened from 500 → 1000 (see registry comment)
    "beta": (0.0, 1.0),
    "v_offset": (0.0, 25.0),
    "f0": (0.5, 0.25),
    "f1": (0.0, 5.0),
    "f2": (0.0, 1e3),
    "f3": (0.0, 0.5),
    "phi0": (0.0, 5.0),
    "contrast": (0.5, 0.25),
    "offset": (1.0, 0.25),
}


def _compute_beta_concentrations(
    mean: float,
    std: float,
    low: float,
    high: float,
) -> tuple[float, float]:
    """Compute Beta concentration parameters from desired mean and std on [low, high].

    Uses the method of moments to find (alpha, beta) such that a
    Beta(alpha, beta) distribution scaled to [low, high] has the specified
    mean and standard deviation.

    The standard Beta(alpha, beta) on [0, 1] has:
        mu_01 = alpha / (alpha + beta)
        var_01 = alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))

    We map the desired mean/std from [low, high] to [0, 1]:
        mu_01 = (mean - low) / (high - low)
        var_01 = (std / (high - low))^2

    Then solve for alpha, beta via method of moments:
        alpha = mu_01 * ((mu_01 * (1 - mu_01)) / var_01 - 1)
        beta  = (1 - mu_01) * ((mu_01 * (1 - mu_01)) / var_01 - 1)

    Args:
        mean: Desired mean, strictly inside (low, high).
        std: Desired standard deviation on [low, high]; must be > 0.
        low: Lower bound.
        high: Upper bound.

    Returns:
        Tuple (concentration1, concentration2), each floored at 0.01.

    Raises:
        ValueError: If the bounds are inverted/degenerate, the mean is not
            strictly inside (low, high), std is not positive, or std is too
            large for a valid Beta distribution.
    """
    if high <= low:
        raise ValueError(f"high ({high}) must be > low ({low})")
    if not (low <= mean <= high):
        raise ValueError(f"mean ({mean}) must be in [{low}, {high}]")
    # Reject std <= 0 (and NaN): std == 0 would divide by zero below, and a
    # negative std would otherwise be silently accepted after squaring.
    if not std > 0.0:
        raise ValueError(f"std ({std}) must be > 0")

    range_width = high - low
    mu_01 = (mean - low) / range_width
    var_01 = (std / range_width) ** 2

    # Variance of Beta on [0,1] must be < mu*(1-mu)
    max_var = mu_01 * (1.0 - mu_01)
    if max_var <= 0.0:
        # A mean on either bound admits no Beta distribution with positive variance.
        raise ValueError(
            f"mean ({mean}) must lie strictly inside ({low}, {high}) "
            f"for a Beta distribution"
        )
    if var_01 >= max_var:
        raise ValueError(
            f"std={std} is too large for Beta on [{low}, {high}] with mean={mean}. "
            f"Max std ~ {(max_var**0.5) * range_width:.4e}"
        )

    # Method of moments
    common = mu_01 * (1.0 - mu_01) / var_01 - 1.0
    alpha = mu_01 * common
    beta_param = (1.0 - mu_01) * common

    # Floor to avoid degenerate distributions
    alpha = max(alpha, 0.01)
    beta_param = max(beta_param, 0.01)

    return alpha, beta_param


def _default_prior(
    name: str,
    info: Any,
) -> PriorDistribution:
    """Build the default prior for a parameter.

    Args:
        name: Parameter name.
        info: ParameterInfo from the registry (``min_bound``/``max_bound`` are read).

    Returns:
        TruncatedNormal prior on the registry bounds if *name* is in
        ``_DEFAULT_PRIOR_SPECS``, otherwise a Uniform prior on those bounds.
    """
    if name in _DEFAULT_PRIOR_SPECS:
        loc, scale = _DEFAULT_PRIOR_SPECS[name]
        return PriorDistribution.truncated_normal(
            loc=loc,
            scale=scale,
            low=info.min_bound,
            high=info.max_bound,
        )
    # Fallback for any unspecified parameter
    return PriorDistribution.uniform(info.min_bound, info.max_bound)


def _build_prior(
    name: str,
    prior_type_str: str,
    prior_params: dict[str, float],
    bounds: tuple[float, float],
) -> PriorDistribution:
    """Build a PriorDistribution from config strings.

    Missing ``prior_params`` entries fall back to defaults:

    * uniform / beta_scaled / truncated_normal: ``low``/``high`` from *bounds*
    * normal / truncated_normal: ``loc`` = bounds midpoint, ``scale`` = width / 4
    * lognormal: ``loc=0.0``, ``scale=1.0``; halfnormal: ``scale=1.0``
    * beta_scaled: ``concentration1 = concentration2 = 2.0``
    * exponential: *prior_params* is stored as given (copied)

    Args:
        name: Parameter name (for error messages)
        prior_type_str: One of "uniform", "normal", "truncated_normal",
            "lognormal", "halfnormal", "exponential", "beta_scaled"
        prior_params: Distribution-specific parameters
            (e.g. {"loc": 0, "scale": 1} for normal)
        bounds: (low, high) bounds, used as fallback where listed above

    Returns:
        Configured PriorDistribution

    Raises:
        ValueError: If *prior_type_str* is not a known prior type.
    """
    try:
        prior_type = PriorType(prior_type_str)
    except ValueError:
        valid = [pt.value for pt in PriorType]
        raise ValueError(
            f"Unknown prior type '{prior_type_str}' for parameter '{name}'. "
            f"Valid types: {valid}"
        ) from None

    if prior_type == PriorType.UNIFORM:
        low = prior_params.get("low", bounds[0])
        high = prior_params.get("high", bounds[1])
        return PriorDistribution.uniform(low, high)
    elif prior_type == PriorType.NORMAL:
        loc = prior_params.get("loc", (bounds[0] + bounds[1]) / 2)
        scale = prior_params.get("scale", (bounds[1] - bounds[0]) / 4)
        return PriorDistribution.normal(loc, scale)
    elif prior_type == PriorType.TRUNCATED_NORMAL:
        loc = prior_params.get("loc", (bounds[0] + bounds[1]) / 2)
        scale = prior_params.get("scale", (bounds[1] - bounds[0]) / 4)
        low = prior_params.get("low", bounds[0])
        high = prior_params.get("high", bounds[1])
        return PriorDistribution.truncated_normal(loc, scale, low, high)
    elif prior_type == PriorType.LOGNORMAL:
        loc = prior_params.get("loc", 0.0)
        scale = prior_params.get("scale", 1.0)
        return PriorDistribution.lognormal(loc, scale)
    elif prior_type == PriorType.HALFNORMAL:
        scale = prior_params.get("scale", 1.0)
        return PriorDistribution.halfnormal(scale)
    elif prior_type == PriorType.EXPONENTIAL:
        # Copy so the prior does not alias (and later mutate) the config dict.
        return PriorDistribution(PriorType.EXPONENTIAL, dict(prior_params))
    elif prior_type == PriorType.BETA_SCALED:
        low = prior_params.get("low", bounds[0])
        high = prior_params.get("high", bounds[1])
        conc1 = prior_params.get("concentration1", 2.0)
        conc2 = prior_params.get("concentration2", 2.0)
        return PriorDistribution.beta_scaled(low, high, conc1, conc2)
    else:
        raise ValueError(f"Unhandled prior type: {prior_type}")
def clamp_to_open_interval(
    value: float,
    low: float,
    high: float,
    epsilon: float = 1e-6,
) -> float:
    """Clamp *value* into ``[low + epsilon, high - epsilon]``.

    For ``epsilon > 0`` the result lies strictly inside ``(low, high)``,
    which is useful for Beta distribution parameters that must be strictly
    within their support bounds.

    If the interval is narrower than ``2 * epsilon``, the inset bounds cross
    and no value satisfies both; the midpoint ``(low + high) / 2`` is
    returned instead, which still lies strictly inside ``(low, high)`` when
    ``high > low``.

    Args:
        value: Value to clamp.
        low: Lower bound of the closed interval.
        high: Upper bound of the closed interval.
        epsilon: Margin to inset from the bounds.

    Returns:
        Clamped value.
    """
    inner_low = low + epsilon
    inner_high = high - epsilon
    if inner_low > inner_high:
        # Interval too narrow for the requested margin: clamping to
        # inner_low would land outside [low, high] when high - low < epsilon.
        return 0.5 * (low + high)
    return max(inner_low, min(value, inner_high))