Source code for heterodyne.config.parameter_manager

"""Parameter manager for heterodyne model optimization."""

from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypedDict

import numpy as np

from heterodyne.config.parameter_names import (
    ALL_PARAM_NAMES,
    ALL_PARAM_NAMES_WITH_SCALING,
    PARAM_GROUPS,
    SCALING_PARAMS,
)
from heterodyne.config.parameter_space import ParameterSpace
from heterodyne.config.physics_validators import ValidationResult, validate_parameters
from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    import jax.numpy as jnp

logger = get_logger(__name__)


class BoundDict(TypedDict):
    """Typed mapping describing the allowed range of a single parameter.

    Keys:
        name: Parameter name the bounds belong to.
        min: Lower bound.
        max: Upper bound.
        type: Label for the bound kind; this module always writes
            ``"TruncatedNormal"``.
    """

    name: str
    min: float
    max: float
    type: str
@dataclass
class ParameterManager:
    """Bridge between a ``ParameterSpace`` and the optimizers that consume it.

    The manager answers four kinds of questions about the wrapped space:

    - which parameters vary and which are held fixed,
    - what the current values and the config-specified initial values are,
    - what bounds apply to each parameter, and
    - whether a set of values passes the physics validators.

    It also converts between the full physics parameter array and the
    subset of entries an optimizer actually varies.

    Bounds queries and active-parameter queries are memoized while
    ``_cache_enabled`` is True, which is the default.
    """

    # The parameter space being managed (values, vary flags, bounds).
    space: ParameterSpace = field(default_factory=ParameterSpace)

    # --- Query caches: none of them is an __init__ argument or part of the
    # --- generated dataclass repr.
    # Bounds lists keyed by the set of requested parameter names; filled on
    # demand by get_parameter_bounds() and cleared by set_bounds().
    _bounds_cache: dict[frozenset[str], list[BoundDict]] = field(
        default_factory=dict, init=False, repr=False
    )
    # Active physics parameter names; filled by get_active_parameters() and
    # reset by set_vary().
    _active_params_cache: list[str] | None = field(
        default=None, init=False, repr=False
    )
    # Switch consulted by get_parameter_bounds() and get_active_parameters()
    # before they read or write their caches.
    _cache_enabled: bool = field(default=True, init=False, repr=False)

    # B006: memoized index / name lists, reset by set_vary().
    _varying_indices_cache: list[int] | None = field(
        default=None, init=False, repr=False
    )
    _fixed_indices_cache: list[int] | None = field(
        default=None, init=False, repr=False
    )
    _varying_names_cache: list[str] | None = field(
        default=None, init=False, repr=False
    )

    # B007: memoized full-values array, reset by update_values().
    _full_values_cache: np.ndarray | None = field(
        default=None, init=False, repr=False
    )

    # Deep copy of space.values taken once in __post_init__ and never written
    # again by this class. get_initial_values() reads from it so every
    # phi-angle optimization starts from the config values, no matter what an
    # earlier fit stored in space.values.
    _initial_values_snapshot: dict[str, float] = field(
        default_factory=dict, init=False, repr=False
    )
def __post_init__(self) -> None:
    """Finish construction: bounds mirror, config overrides, value snapshot.

    Three steps, in order:

    1. Build ``_default_bounds`` with one ``BoundDict`` per name in
       ``ALL_PARAM_NAMES_WITH_SCALING``, using the registry's min/max.
    2. Overlay the bounds held by ``self.space`` so both sources agree
       (homodyne parity with ``_load_config_bounds``).
    3. Deep-copy ``self.space.values`` into ``_initial_values_snapshot`` so
       get_initial_values() keeps returning the config starting point even
       after later fits change the live values.
    """
    # Imported here rather than at module level, as in the original code.
    from heterodyne.config.parameter_registry import DEFAULT_REGISTRY

    self._default_bounds: dict[str, BoundDict] = {
        param_name: BoundDict(
            name=param_name,
            min=DEFAULT_REGISTRY[param_name].min_bound,
            max=DEFAULT_REGISTRY[param_name].max_bound,
            type="TruncatedNormal",
        )
        for param_name in ALL_PARAM_NAMES_WITH_SCALING
    }

    # Config-specified bounds take precedence over the registry defaults.
    self._sync_bounds_from_space()

    # Frozen starting point, independent of later space.values mutations.
    self._initial_values_snapshot = copy.deepcopy(self.space.values)
def _sync_bounds_from_space(self) -> None:
    """Copy the bounds held by ``self.space`` into ``_default_bounds``.

    For every name in ``ALL_PARAM_NAMES_WITH_SCALING`` the bounds are read
    from ``self.space.bounds``, falling back to the registry's min/max when
    the space has no entry. Names present in ``_default_bounds`` then get
    their ``min``/``max`` overwritten; a debug message is logged when the
    result differs from the registry values.

    Homodyne parity: plays the role of ``_load_config_bounds()``, except
    that the overrides are already stored on the space.
    """
    from heterodyne.config.parameter_registry import DEFAULT_REGISTRY

    for param_name in ALL_PARAM_NAMES_WITH_SCALING:
        registry_entry = DEFAULT_REGISTRY[param_name]
        lower, upper = self.space.bounds.get(
            param_name,
            (registry_entry.min_bound, registry_entry.max_bound),
        )
        if param_name not in self._default_bounds:
            continue

        if lower != registry_entry.min_bound or upper != registry_entry.max_bound:
            logger.debug(
                "Config overrides bounds for %s: [%.4g, %.4g] -> [%.4g, %.4g]",
                param_name,
                registry_entry.min_bound,
                registry_entry.max_bound,
                lower,
                upper,
            )
        mirror_entry = self._default_bounds[param_name]
        mirror_entry["min"] = lower
        mirror_entry["max"] = upper


# ------------------------------------------------------------------
# Core existing API
# ------------------------------------------------------------------
@property
def n_params(self) -> int:
    """Total number of physics model parameters (14)."""
    return len(ALL_PARAM_NAMES)


@property
def n_varying(self) -> int:
    """Number of physics parameters that vary in optimization."""
    return len(self.varying_names)


@property
def varying_names(self) -> list[str]:
    """Names of the varying physics parameters (scaling excluded).

    The list is read from ``space.varying_physics_names`` once and memoized
    until set_vary() resets it; callers always receive a fresh copy.
    """
    cached = self._varying_names_cache
    if cached is None:
        cached = self._varying_names_cache = self.space.varying_physics_names
    return list(cached)


@property
def varying_indices(self) -> list[int]:
    """Positions of the varying parameters within ``ALL_PARAM_NAMES``.

    Memoized until set_vary() resets it; callers always receive a copy.
    """
    if self._varying_indices_cache is None:
        vary_flags = self.space.vary
        self._varying_indices_cache = [
            position
            for position, param_name in enumerate(ALL_PARAM_NAMES)
            if vary_flags.get(param_name, False)
        ]
    return list(self._varying_indices_cache)


@property
def fixed_indices(self) -> list[int]:
    """Positions of the fixed parameters within ``ALL_PARAM_NAMES``.

    A parameter with no vary flag counts as fixed. Memoized until
    set_vary() resets it; callers always receive a copy.
    """
    if self._fixed_indices_cache is None:
        vary_flags = self.space.vary
        self._fixed_indices_cache = [
            position
            for position, param_name in enumerate(ALL_PARAM_NAMES)
            if not vary_flags.get(param_name, False)
        ]
    return list(self._fixed_indices_cache)
def get_initial_values(self) -> np.ndarray:
    """Return the config-specified starting point for the varying parameters.

    Values come from the snapshot taken at construction time, not from the
    live ``space.values``, so repeated calls (for example across a
    multi-angle loop) return the same numbers even after a fit has stored
    new values. A name missing from the snapshot falls back to the live
    value, and to ``0.0`` if that is missing too.

    Returns:
        Array of shape (n_varying,) holding the initial values of the
        varying physics parameters.
    """
    frozen = self._initial_values_snapshot
    live = self.space.values
    start_full = np.array(
        [
            frozen.get(param_name, live.get(param_name, 0.0))
            for param_name in ALL_PARAM_NAMES
        ]
    )
    return start_full[self.varying_indices]
def get_full_values(self) -> np.ndarray:
    """Return the full physics parameter array (all 14 values).

    The result is cached and marked read-only (``writeable=False``); call
    ``.copy()`` on it if you need to modify the values. The cache is reset
    by update_values().

    Returns:
        Read-only array of shape (14,).
    """
    if self._full_values_cache is None:
        # Copy before freezing. Clearing the writeable flag on the array
        # handed back by the space would also make the space's own array
        # read-only whenever it returns a shared array instead of a fresh one.
        frozen = np.array(self.space.get_initial_array(), copy=True)
        frozen.flags.writeable = False
        self._full_values_cache = frozen
    return self._full_values_cache
def get_bounds(self) -> tuple[np.ndarray, np.ndarray]:
    """Return lower and upper bounds restricted to the varying parameters.

    The full bound arrays come from ``space.get_bounds_arrays()`` and are
    indexed with ``varying_indices``.

    Returns:
        ``(lower, upper)``, each of shape (n_varying,).
    """
    all_lower, all_upper = self.space.get_bounds_arrays()
    selected = self.varying_indices
    lower = all_lower[selected]
    upper = all_upper[selected]
    return lower, upper
def expand_varying_to_full(
    self,
    varying_params: np.ndarray | jnp.ndarray,
) -> np.ndarray:
    """Scatter the varying values into a copy of the full parameter array.

    Entries belonging to fixed parameters keep their stored values.

    Args:
        varying_params: Array of shape (n_varying,), ordered like
            ``varying_indices``.

    Returns:
        New writable array of shape (14,).
    """
    # get_full_values() hands out a read-only cached array; work on a copy.
    expanded = self.get_full_values().copy()
    for source_pos, target_pos in enumerate(self.varying_indices):
        # float() converts each element to a plain Python float before storing.
        expanded[target_pos] = float(varying_params[source_pos])
    return expanded
def extract_varying(self, full_params: np.ndarray | jnp.ndarray) -> np.ndarray:
    """Pick the varying entries out of a full parameter array.

    Args:
        full_params: Array of shape (14,).

    Returns:
        Array of shape (n_varying,), ordered like ``varying_indices``.
    """
    wanted = self.varying_indices
    picked = [full_params[position] for position in wanted]
    return np.array(picked)
def update_values(self, params: np.ndarray | dict[str, float]) -> None:
    """Store new parameter values on the underlying space.

    Args:
        params: Either a dict keyed by parameter name, or an array of
            shape (14,) that is first converted with
            ``space.array_to_dict``.
    """
    if isinstance(params, dict):
        as_mapping = params
    else:
        as_mapping = self.space.array_to_dict(np.asarray(params))
    self.space.update_from_dict(as_mapping)

    # Stored values changed, so the cached full-values array is stale.
    self._full_values_cache = None
def get_parameter_dict(self) -> dict[str, float]:
    """Return a shallow copy of the current name -> value mapping."""
    current_values = dict(self.space.values)
    return current_values
def set_vary(self, name: str, vary: bool) -> None:
    """Mark a parameter as varying or fixed for optimization.

    Args:
        name: Parameter name (physics or scaling).
        vary: True to let the optimizer vary it, False to hold it fixed.

    Raises:
        ValueError: If ``name`` is not in ``ALL_PARAM_NAMES_WITH_SCALING``.
    """
    if name not in ALL_PARAM_NAMES_WITH_SCALING:
        raise ValueError(f"Unknown parameter: {name}")

    self.space.vary[name] = vary

    # Every memoized view derived from the vary flags is now stale.
    self._active_params_cache = None
    self._varying_names_cache = None
    self._varying_indices_cache = self._fixed_indices_cache = None
def set_bounds(self, name: str, lower: float, upper: float) -> None:
    """Set the bounds of one parameter.

    Writes the pair to ``space.bounds``, mirrors it into
    ``_default_bounds`` and clears the whole bounds cache so the next
    get_parameter_bounds() call sees the new range.

    Args:
        name: Parameter name (physics or scaling).
        lower: Lower bound.
        upper: Upper bound; must not be smaller than ``lower``.

    Raises:
        ValueError: If ``name`` is unknown, or if the bounds are inverted
            or contain NaN.
    """
    if name not in ALL_PARAM_NAMES_WITH_SCALING:
        raise ValueError(f"Unknown parameter: {name}")
    # Written as ``not (lower <= upper)`` so that NaN on either side is
    # rejected too; every comparison involving NaN is False.
    if not (lower <= upper):
        raise ValueError(
            f"Invalid bounds for {name}: "
            f"lower={lower!r} must not exceed upper={upper!r}"
        )

    self.space.bounds[name] = (lower, upper)

    # Keep the local mirror in step with the space.
    mirror_entry = self._default_bounds.get(name)
    if mirror_entry is not None:
        mirror_entry["min"] = lower
        mirror_entry["max"] = upper

    # Cached lists may contain this parameter; drop them all.
    self._bounds_cache.clear()
def validate_physics(self, params: np.ndarray | None = None) -> list[str]:
    """Run the physics validators and return every message they produce.

    Args:
        params: Full parameter array of shape (14,). When None, the
            stored values from get_full_values() are validated.

    Returns:
        Errors followed by warnings; empty when nothing was reported.
    """
    candidate = self.get_full_values() if params is None else params
    report = validate_parameters(candidate)
    return report.errors + report.warnings
@classmethod
def from_config(cls, config: dict[str, Any]) -> ParameterManager:
    """Build a manager from a configuration dictionary.

    Args:
        config: Full configuration dict, passed unchanged to
            ``ParameterSpace.from_config``.

    Returns:
        A manager wrapping the resulting parameter space.
    """
    return cls(space=ParameterSpace.from_config(config))
def get_group_values(self, group: str) -> dict[str, float]:
    """Return the current values of every parameter in one group.

    Args:
        group: Group name ('reference', 'sample', 'velocity', 'fraction',
            'angle', 'scaling').

    Returns:
        Dict mapping each parameter name in the group to its value.

    Raises:
        ValueError: If ``group`` is not a key of ``PARAM_GROUPS``.
    """
    if group not in PARAM_GROUPS:
        raise ValueError(f"Unknown group: {group}")

    current = self.space.values
    members = PARAM_GROUPS[group]
    return {member: current[member] for member in members}
# ------------------------------------------------------------------ # New API: bounds queries # ------------------------------------------------------------------
def get_parameter_bounds(
    self,
    parameter_names: list[str] | None = None,
) -> list[BoundDict]:
    """Return bound specifications for the requested parameters, with caching.

    Args:
        parameter_names: Names to look up. If None, all 16 parameters
            (14 physics + 2 scaling) are returned in the order of
            ``ALL_PARAM_NAMES_WITH_SCALING``.

    Returns:
        One ``BoundDict`` (keys 'name', 'min', 'max', 'type') per requested
        name, in the order the names were requested. The dicts are copies;
        mutating them does not affect the cache.

    Notes:
        Results are cached per *set* of names. A cache hit is re-emitted in
        the caller's order, so requests that differ only in ordering or in
        repeated names share one cache entry and still get correctly
        aligned results. The cache is cleared by set_bounds().
    """
    if parameter_names is None:
        parameter_names = list(ALL_PARAM_NAMES_WITH_SCALING)
    else:
        # Materialize once: the names are iterated several times below.
        parameter_names = list(parameter_names)

    cache_key = frozenset(parameter_names)
    if self._cache_enabled:
        cached = self._bounds_cache.get(cache_key)
        if cached is not None:
            logger.debug(
                "Returning cached bounds for %d parameters", len(parameter_names)
            )
            # The key ignores order and duplicates, so the stored list may
            # be ordered differently from this request. Re-emit per name.
            by_name = {entry["name"]: entry for entry in cached}
            return [by_name[name].copy() for name in parameter_names]  # type: ignore[return-value]

    bounds_list: list[BoundDict] = []
    for name in parameter_names:
        if name in self._default_bounds:
            # Prefer the live space.bounds entry (it may have been changed
            # through set_bounds()); fall back to the local mirror.
            lo, hi = self.space.bounds.get(
                name,
                (
                    self._default_bounds[name]["min"],
                    self._default_bounds[name]["max"],
                ),
            )
            bounds_list.append(
                BoundDict(name=name, min=lo, max=hi, type="TruncatedNormal")
            )
        else:
            logger.warning(
                "Unknown parameter '%s', using default bounds [0.0, 1.0]", name
            )
            bounds_list.append(
                BoundDict(name=name, min=0.0, max=1.0, type="TruncatedNormal")
            )

    if self._cache_enabled:
        # Store copies so callers mutating the result cannot corrupt the cache.
        self._bounds_cache[cache_key] = [b.copy() for b in bounds_list]  # type: ignore[misc]
    return bounds_list
def get_bounds_as_tuples(
    self,
    parameter_names: list[str] | None = None,
) -> list[tuple[float, float]]:
    """Return bounds as ``(min, max)`` pairs (scipy-style).

    Thin wrapper around get_parameter_bounds() for optimization code that
    expects a list of pairs.

    Args:
        parameter_names: Parameter names. If None, uses all 16 parameters.

    Returns:
        One ``(min, max)`` tuple per parameter.
    """
    specs = self.get_parameter_bounds(parameter_names)
    return [(spec["min"], spec["max"]) for spec in specs]
def get_bounds_as_arrays(
    self,
    parameter_names: list[str] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Return bounds as separate lower and upper numpy arrays.

    Thin wrapper around get_parameter_bounds() for NLSQ and JAX optimizers
    that take the two sides as separate arrays.

    Args:
        parameter_names: Parameter names. If None, uses all 16 parameters.

    Returns:
        ``(lower_bounds, upper_bounds)``, each of shape (n_params,).
    """
    minima: list[float] = []
    maxima: list[float] = []
    for spec in self.get_parameter_bounds(parameter_names):
        minima.append(spec["min"])
        maxima.append(spec["max"])
    return np.array(minima), np.array(maxima)
# ------------------------------------------------------------------ # New API: active / fixed / optimizable parameter queries # ------------------------------------------------------------------
def get_active_parameters(self) -> list[str]:
    """Return the physics parameter names treated as active.

    These are the names reported by ``space.varying_physics_names``
    (scaling excluded). When that list is empty, for instance on a manager
    whose vary flags are all False, every name in ``ALL_PARAM_NAMES`` is
    returned instead.

    The result is cached while caching is enabled; set_vary() resets it.

    Returns:
        List of active physics parameter names in canonical order.
    """
    caching = self._cache_enabled
    if caching and self._active_params_cache is not None:
        logger.debug("Returning cached active parameters")
        return list(self._active_params_cache)

    # Empty list -> fall back to all physics parameters.
    active = self.space.varying_physics_names or list(ALL_PARAM_NAMES)

    if caching:
        # Cache a private copy; the caller gets the original list.
        self._active_params_cache = list(active)
    return active
def get_all_parameter_names(self) -> list[str]:
    """Return every parameter name, scaling parameters first.

    Returns:
        New list of 16 names: the entries of ``SCALING_PARAMS`` (contrast,
        offset) followed by the 14 entries of ``ALL_PARAM_NAMES``.
    """
    return [*SCALING_PARAMS, *ALL_PARAM_NAMES]
def get_effective_parameter_count(self) -> int:
    """Return how many physics parameters are active (scaling excluded).

    This is the length of get_active_parameters(), so it inherits that
    method's fallback: when no physics parameter is flagged as varying,
    all physics parameters are counted.

    Returns:
        Number of active physics parameters.
    """
    active_names = self.get_active_parameters()
    return len(active_names)
def get_total_parameter_count(self) -> int:
    """Return the number of parameters, scaling and physics combined.

    Returns:
        ``len(ALL_PARAM_NAMES_WITH_SCALING)``: 16 for the heterodyne model
        (14 physics + 2 scaling).
    """
    total = len(ALL_PARAM_NAMES_WITH_SCALING)
    return total
def get_fixed_parameters(self) -> dict[str, float]:
    """Return the physics parameters held fixed, with their current values.

    A physics parameter is fixed when its vary flag is False or missing.
    Scaling parameters (contrast, offset) are never included; use
    get_parameter_dict() to read their values.

    Returns:
        Dict mapping each fixed physics parameter name to its value.
    """
    fixed: dict[str, float] = {}
    for param_name in ALL_PARAM_NAMES:
        if self.space.vary.get(param_name, False):
            continue
        fixed[param_name] = self.space.values[param_name]
    return fixed
def is_parameter_active(self, param_name: str) -> bool:
    """Tell whether a physics parameter has ``vary=True``.

    Args:
        param_name: Name to check. Anything that is not one of the
            physics parameters, scaling names included, yields False.

    Returns:
        True if the name is a physics parameter whose vary flag is truthy,
        False otherwise.
    """
    return param_name in ALL_PARAM_NAMES and bool(
        self.space.vary.get(param_name, False)
    )
def get_optimizable_parameters(self) -> list[str]:
    """Return the physics parameters that should be optimized.

    Alias for get_active_parameters(); scaling parameters are handled
    separately and are not included.

    Returns:
        List of active physics parameter names, in canonical order.
    """
    optimizable = self.get_active_parameters()
    return optimizable
# ------------------------------------------------------------------ # New API: physics constraint validation with severity # ------------------------------------------------------------------
def validate_physical_constraints(
    self,
    params: dict[str, float] | np.ndarray | None = None,
    severity_level: str = "warning",
) -> ValidationResult:
    """Run the physics validators on stored, overridden, or supplied values.

    Args:
        params: What to validate.

            - None: the stored values.
            - dict: the stored values with the given names overridden.
              Keys that are not known parameter names are ignored and
              reported in a log warning.
            - array: used as given, converted to float; expected shape (14,).
        severity_level: ``"error"`` drops warnings from the result and
            judges validity on errors alone. Any other value (including
            the documented ``"warning"`` and ``"info"``) returns the
            validator's result unchanged; the argument exists mainly for
            API parity with homodyne.

    Returns:
        ValidationResult with ``is_valid``, ``errors``, and ``warnings``.
    """
    if params is None:
        candidate = self.get_full_values()
    elif not isinstance(params, dict):
        candidate = np.asarray(params, dtype=float)
    else:
        # Start from the stored values and overlay the recognised overrides.
        stored = self.get_full_values().copy()
        merged = self.space.array_to_dict(stored)

        unknown_keys = []
        overrides = {}
        for key, value in params.items():
            if key in merged:
                overrides[key] = value
            else:
                unknown_keys.append(key)

        if unknown_keys:
            logger.warning(
                "validate_physical_constraints: ignoring unknown "
                "parameter override keys %s (valid keys: %s). "
                "Misspellings will silently revert to stored values.",
                unknown_keys,
                sorted(merged.keys()),
            )
        merged.update(overrides)
        candidate = np.array([merged[name] for name in ALL_PARAM_NAMES])

    outcome = validate_parameters(candidate)
    if severity_level != "error":
        return outcome

    # Errors only: warnings are discarded and do not affect validity.
    return ValidationResult(
        is_valid=len(outcome.errors) == 0,
        errors=outcome.errors,
        warnings=[],
    )
# ------------------------------------------------------------------ # Dunder # ------------------------------------------------------------------
def __repr__(self) -> str:
    """Return a one-line summary of the manager's parameter counts."""
    active_count = len(self.get_active_parameters())
    fixed_count = len(self.get_fixed_parameters())
    scaling_varying_count = len(
        [name for name in SCALING_PARAMS if self.space.vary.get(name, False)]
    )
    summary = [
        f"n_physics={self.n_params}",
        f"active={active_count}",
        f"fixed={fixed_count}",
        f"scaling_varying={scaling_varying_count}",
        f"total={self.get_total_parameter_count()}",
    ]
    return "ParameterManager(" + ", ".join(summary) + ")"