Source code for heterodyne.config.manager

"""Configuration manager for heterodyne analysis."""

from __future__ import annotations

import copy
import logging
from pathlib import Path
from typing import Any, cast

import yaml

from heterodyne.utils.path_validation import validate_file_exists

logger = logging.getLogger(__name__)

_ALLOWED_OPTIMIZATION_METHODS = {"nlsq", "cmc"}


[docs] class ConfigurationError(Exception): """Raised when configuration is invalid."""
[docs] class ConfigManager: """Manager for heterodyne analysis configuration. Handles loading, validation, and access to configuration settings. """
[docs] def __init__(self, config: dict[str, Any]) -> None: """Initialize with configuration dictionary. Args: config: Configuration dictionary """ self._config = copy.deepcopy(config) self._normalize_schema() self._validate()
def _normalize_schema(self) -> None: """Normalize deprecated configuration keys to canonical names.""" from heterodyne.config.types import PARAMETER_NAME_MAPPING # Normalize parameter names in all parameter groups params: dict[str, Any] = self._config.get("parameters", {}) for group_name in list(params.keys()): group_config = params[group_name] if not isinstance(group_config, dict): continue normalized: dict[str, Any] = {} seen_sources: dict[str, str] = {} for key, value in group_config.items(): canonical: str = PARAMETER_NAME_MAPPING.get(str(key), str(key)) if canonical in normalized and seen_sources.get(canonical) != key: raise ConfigurationError( f"Parameter group '{group_name}' contains both " f"'{seen_sources[canonical]}' and '{key}' which both " f"resolve to canonical key '{canonical}'. " f"Remove one to avoid silent value loss." ) if canonical != key: logger.debug( "Normalized parameter key '%s' -> '%s'", key, canonical ) normalized[canonical] = value seen_sources[canonical] = key params[group_name] = normalized # Normalize legacy temporal/scattering sections into analyzer_parameters self._normalize_analyzer_parameters() # Normalize CMC config keys cmc = self._config.get("optimization", {}).get("cmc", {}) if isinstance(cmc, dict): normalized_cmc: dict[str, Any] = {} cmc_seen: dict[str, str] = {} for key, value in cmc.items(): cmc_canonical: str = PARAMETER_NAME_MAPPING.get(str(key), str(key)) if ( cmc_canonical in normalized_cmc and cmc_seen.get(cmc_canonical) != key ): raise ConfigurationError( f"CMC config contains both '{cmc_seen[cmc_canonical]}' " f"and '{key}' which both resolve to canonical key " f"'{cmc_canonical}'. Remove one to avoid silent value loss." ) if cmc_canonical != key: logger.debug("Normalized CMC key '%s' -> '%s'", key, cmc_canonical) normalized_cmc[cmc_canonical] = value cmc_seen[cmc_canonical] = key opt = self._config.get("optimization") if isinstance(opt, dict) and "cmc" in opt: opt["cmc"] = normalized_cmc def _normalize_analyzer_parameters(self) -> None: """Merge legacy ``temporal``/``scattering`` sections into ``analyzer_parameters``. Supports three config styles: 1. **New** — ``analyzer_parameters`` with dt, start_frame, end_frame, scattering, and geometry sub-keys. 2. **Legacy** — separate ``temporal`` and ``scattering`` top-level sections. 3. **Mixed** — ``analyzer_parameters`` exists but legacy sections also present; legacy values are used as fallbacks only. After normalization, ``temporal`` and ``scattering`` top-level keys are synthesized from ``analyzer_parameters`` so downstream code that reads the raw config dict keeps working during migration. """ ap = self._config.get("analyzer_parameters", {}) temporal = self._config.get("temporal", {}) scattering = self._config.get("scattering", {}) if not ap and not temporal and not scattering: # Nothing to normalize — will fail validation later return if not ap and (temporal or scattering): logger.info( "Migrating legacy 'temporal'/'scattering' sections " "into 'analyzer_parameters'" ) # --- Build canonical analyzer_parameters -------------------------- # Start from a deep copy of the existing analyzer_parameters so any # unknown user-supplied fields (e.g. temperature, beamline metadata) # survive migration instead of being silently dropped by the # whitelist below. merged: dict[str, Any] = copy.deepcopy(ap) if isinstance(ap, dict) else {} # dt: top-level in analyzer_parameters (parity with homodyne) merged["dt"] = ap.get("dt", temporal.get("dt", 1.0)) # Frame range: prefer start_frame/end_frame (1-indexed, inclusive) if "start_frame" in ap: merged["start_frame"] = int(ap["start_frame"]) elif "t_start" in temporal: # Legacy: t_start is 0-indexed → start_frame is 1-indexed merged["start_frame"] = int(temporal["t_start"]) + 1 else: merged["start_frame"] = 1 if "end_frame" in ap: merged["end_frame"] = int(ap["end_frame"]) elif "time_length" in temporal: # Legacy: end_frame = t_start + time_length (inclusive) t_start = int(temporal.get("t_start", 0)) merged["end_frame"] = t_start + int(temporal["time_length"]) else: merged["end_frame"] = 1000 # Scattering sub-section ap_scat = ap.get("scattering", {}) merged_scat: dict[str, Any] = {} merged_scat["wavevector_q"] = ap_scat.get( "wavevector_q", scattering.get("wavevector_q", 0.01) ) # phi_angles (optional) phi = ap_scat.get("phi_angles", scattering.get("phi_angles")) if phi is not None: merged_scat["phi_angles"] = phi merged["scattering"] = merged_scat # Geometry sub-section (new — parity with homodyne) ap_geom = ap.get("geometry", {}) if ap_geom: merged["geometry"] = dict(ap_geom) self._config["analyzer_parameters"] = merged # --- Synthesize legacy keys for downstream raw-config readers ----- start_frame = merged["start_frame"] end_frame = merged["end_frame"] t_start = start_frame - 1 # 1-indexed → 0-indexed time_length = end_frame - t_start # inclusive range length self._config["temporal"] = { "dt": merged["dt"], "time_length": time_length, "t_start": t_start, } self._config["scattering"] = { "wavevector_q": merged["scattering"]["wavevector_q"], } if "phi_angles" in merged["scattering"]: self._config["scattering"]["phi_angles"] = merged["scattering"][ "phi_angles" ] def _validate(self) -> None: """Validate configuration structure.""" required_sections = [ "experimental_data", "analyzer_parameters", "parameters", ] missing = [s for s in required_sections if s not in self._config] if missing: raise ConfigurationError(f"Missing required sections: {missing}") # Warn if optimization section is absent; validate method if present if "optimization" not in self._config: logger.warning( "Configuration has no 'optimization' section; " "defaults will be used (method='nlsq')" ) else: method = self._config["optimization"].get("method") if method is not None and method not in _ALLOWED_OPTIMIZATION_METHODS: raise ConfigurationError( f"Invalid optimization method '{method}'. " f"Allowed values: {sorted(_ALLOWED_OPTIMIZATION_METHODS)}" ) cmc_section = self._config["optimization"].get("cmc") if isinstance(cmc_section, dict): cmc_errors = self._validate_cmc_config(cmc_section) if cmc_errors: raise ConfigurationError( "Invalid CMC configuration: " + "; ".join(cmc_errors) ) # Validate config_version if present self._validate_config_version() def _validate_config_version(self) -> None: """Warn if config_version doesn't match package version.""" metadata = self._config.get("metadata", {}) config_version = metadata.get("config_version") if config_version is None: return try: from heterodyne._version import __version__ # Compare major.minor only (patch mismatches are fine) cv_parts = str(config_version).split(".")[:2] pkg_parts = __version__.split(".")[:2] if cv_parts != pkg_parts: logger.warning( "Config version %s does not match package version %s. " "Some settings may have changed.", config_version, __version__, ) except ImportError: pass # Version not available (editable install without SCM)
[docs] @classmethod def from_yaml(cls, path: Path | str) -> ConfigManager: """Load configuration from YAML file. Args: path: Path to YAML file Returns: ConfigManager instance """ path = validate_file_exists(path, "Configuration file") with open(path, encoding="utf-8") as f: config = yaml.safe_load(f) return cls(config)
[docs] @classmethod def from_dict(cls, config: dict[str, Any]) -> ConfigManager: """Create from dictionary. Args: config: Configuration dictionary Returns: ConfigManager instance """ return cls(config)
[docs] @classmethod def from_json(cls, path: Path | str) -> ConfigManager: """Load configuration from JSON file. Args: path: Path to JSON file Returns: ConfigManager instance """ import json path = validate_file_exists(path, "Configuration file") with open(path, encoding="utf-8") as f: config = json.load(f) return cls(config)
@property def raw_config(self) -> dict[str, Any]: """Get raw configuration dictionary (deep copy to prevent mutation).""" return copy.deepcopy(self._config) # === Experimental Data === @property def data_file_path(self) -> Path: """Path to experimental data file.""" return Path(self._config["experimental_data"]["file_path"]) @property def data_folder_path(self) -> Path | None: """Optional folder path for data.""" path = self._config["experimental_data"].get("data_folder_path") return Path(path) if path else None @property def file_format(self) -> str: """Data file format.""" return cast(str, self._config["experimental_data"].get("file_format", "hdf5")) # === Cache Settings === @property def cache_file_path(self) -> Path | None: """Directory for cache files (falls back to data_folder_path, then None).""" path = self._config["experimental_data"].get("cache_file_path", "") if path: return Path(path) # Homodyne parity: fall back to data_folder_path when cache_file_path is empty return self.data_folder_path @property def cache_filename_template(self) -> str | None: """Template for cache filenames with ${variable} substitution.""" tmpl = self._config["experimental_data"].get("cache_filename_template", "") return tmpl if tmpl else None @property def cache_compression(self) -> bool: """Whether to compress cache files.""" return bool(self._config["experimental_data"].get("cache_compression", True)) # === Analyzer Parameters === @property def _ap(self) -> dict[str, Any]: """Canonical analyzer_parameters section.""" return cast(dict[str, Any], self._config["analyzer_parameters"]) @property def dt(self) -> float: """Time step [seconds].""" return float(self._ap["dt"]) @property def start_frame(self) -> int: """Starting frame (1-indexed).""" return int(self._ap["start_frame"]) @property def end_frame(self) -> int: """Ending frame (1-indexed, inclusive).""" return int(self._ap["end_frame"]) @property def time_length(self) -> int: """Number of time points (derived from frame range).""" return self.end_frame - (self.start_frame - 1) @property def t_start(self) -> int: """Starting time index, 0-indexed (derived from start_frame).""" return self.start_frame - 1 @property def wavevector_q(self) -> float: """Scattering wavevector magnitude [Å⁻¹].""" return float(self._ap["scattering"]["wavevector_q"]) @property def phi_angles(self) -> list[float] | None: """List of phi angles for analysis.""" angles = self._ap["scattering"].get("phi_angles") return [float(a) for a in angles] if angles else None @property def stator_rotor_gap(self) -> float | None: """Stator-rotor gap [Å] (optional geometry metadata).""" geom = self._ap.get("geometry", {}) gap = geom.get("stator_rotor_gap") return float(gap) if gap is not None else None # === Parameter Settings === @property def parameters_config(self) -> dict[str, Any]: """Get parameters configuration section.""" return cast(dict[str, Any], self._config.get("parameters", {}))
[docs] def get_parameter_value(self, group: str, name: str) -> float: """Get a specific parameter value. Args: group: Parameter group ('reference', 'sample', etc.) name: Parameter name within group Returns: Parameter value """ group_config = self._config["parameters"].get(group, {}) param_config = group_config.get(name, {}) if isinstance(param_config, dict): if "value" not in param_config: raise ConfigurationError( f"Parameter '{name}' in group '{group}' is missing " f"the required 'value' key" ) return float(param_config["value"]) return float(param_config)
[docs] def get_parameter_vary(self, group: str, name: str) -> bool: """Check if parameter varies in optimization. Args: group: Parameter group name: Parameter name Returns: Whether parameter varies """ group_config = self._config["parameters"].get(group, {}) param_config = group_config.get(name, {}) if isinstance(param_config, dict): return bool(param_config.get("vary", True)) return True
# === Optimization Settings === @property def optimization_method(self) -> str: """Optimization method ('nlsq' or 'cmc').""" return cast(str, self._config.get("optimization", {}).get("method", "nlsq")) @property def nlsq_config(self) -> dict[str, Any]: """NLSQ optimization settings (returns a copy to prevent mutation).""" return copy.deepcopy( cast(dict[str, Any], self._config.get("optimization", {}).get("nlsq", {})) ) @property def cmc_config(self) -> dict[str, Any]: """CMC analysis settings (returns a copy to prevent mutation).""" return copy.deepcopy( cast(dict[str, Any], self._config.get("optimization", {}).get("cmc", {})) ) def _merge_cmc_config(self) -> dict[str, Any]: """Merge CMC config with sensible defaults. Config values override defaults. Returns: Merged CMC configuration dictionary """ defaults: dict[str, Any] = { "num_warmup": 500, "num_samples": 1000, "num_chains": 4, "target_accept_prob": 0.8, "max_tree_depth": 10, } cmc = self._config.get("optimization", {}).get("cmc", {}) if isinstance(cmc, dict): defaults.update(cmc) return defaults def _validate_cmc_config(self, cmc_config: dict[str, Any]) -> list[str]: """Validate CMC config values. Args: cmc_config: CMC configuration dictionary to validate Returns: List of error messages (empty if valid) """ errors: list[str] = [] num_warmup = cmc_config.get("num_warmup") if num_warmup is not None and ( not isinstance(num_warmup, int) or num_warmup <= 0 ): errors.append(f"num_warmup must be > 0, got {num_warmup}") num_samples = cmc_config.get("num_samples") if num_samples is not None and ( not isinstance(num_samples, int) or num_samples <= 0 ): errors.append(f"num_samples must be > 0, got {num_samples}") num_chains = cmc_config.get("num_chains") if num_chains is not None and ( not isinstance(num_chains, int) or num_chains <= 0 ): errors.append(f"num_chains must be > 0, got {num_chains}") target_accept_prob = cmc_config.get("target_accept_prob") if target_accept_prob is not None and ( not isinstance(target_accept_prob, (int, float)) or target_accept_prob <= 0 or target_accept_prob >= 1 ): errors.append( f"target_accept_prob must be in (0, 1), got {target_accept_prob}" ) max_tree_depth = cmc_config.get("max_tree_depth") if max_tree_depth is not None and ( not isinstance(max_tree_depth, int) or max_tree_depth < 1 or max_tree_depth > 20 ): errors.append(f"max_tree_depth must be in [1, 20], got {max_tree_depth}") return errors
[docs] def update_optimization_config(self, section: str, key: str, value: Any) -> None: """Update a single optimization config key in-place. Args: section: Optimization sub-section ("nlsq" or "cmc"). key: Configuration key to update. value: New value for the key. """ self._config.setdefault("optimization", {}).setdefault(section, {})[key] = value
[docs] def get_config(self) -> dict[str, Any]: """Return a deep copy of the raw configuration dictionary. Returns: Deep copy of the full configuration dictionary """ return copy.deepcopy(self._config)
[docs] def get_cmc_config(self) -> dict[str, Any]: """Return merged CMC config with defaults applied. Returns: CMC configuration with defaults merged in """ return self._merge_cmc_config()
# === Output Settings === @property def output_dir(self) -> Path: """Output directory path.""" output = self._config.get("output", {}) return Path(output.get("output_dir", "./output"))
[docs] def to_yaml(self, path: Path | str) -> None: """Save configuration to YAML file. Args: path: Output path """ path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) with open(path, "w", encoding="utf-8") as f: yaml.dump(self._config, f, default_flow_style=False, sort_keys=False)
[docs] def load_xpcs_config(path: Path | str) -> ConfigManager: """Load XPCS analysis configuration from file. Convenience function for loading configuration. Args: path: Path to YAML configuration file Returns: ConfigManager instance """ return ConfigManager.from_yaml(path)