"""Configuration for NLSQ optimization in the heterodyne analysis pipeline.
This module defines the full configuration hierarchy for non-linear least squares
fitting of heterodyne XPCS correlation functions:
- ``HybridRecoveryConfig`` — progressive retry / fallback parameters
- ``NLSQValidationConfig`` — post-fit validation thresholds
- ``NLSQConfig`` — master configuration dataclass
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass, field
from typing import Any, Literal
import numpy as np
from heterodyne.utils.logging import get_logger
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Safe type-conversion utilities
# ---------------------------------------------------------------------------
_SENTINEL = object()
[docs]
def safe_float(value: Any, default: float) -> float:
    """Convert *value* to float, returning *default* on failure.

    ``None`` is treated as "not provided" and yields *default* silently.
    Any other value that cannot be converted is reported through the module
    logger before *default* is returned.

    Args:
        value: Arbitrary input that should be numeric.
        default: Fallback value when conversion fails.

    Returns:
        ``float(value)`` on success, *default* otherwise.
    """
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers integers too large for a double
        # (e.g. ``10**400``), which ``float()`` rejects instead of
        # saturating to infinity.
        logger.warning(
            "safe_float: could not convert %r to float, using default %s",
            value,
            default,
        )
        return default
[docs]
def safe_int(value: Any, default: int) -> int:
    """Convert *value* to int, returning *default* on failure.

    ``None`` is treated as "not provided" and yields *default* silently.
    Any other value that cannot be converted is reported through the module
    logger before *default* is returned. Finite floats are truncated
    towards zero, exactly as ``int()`` does.

    Args:
        value: Arbitrary input that should be integral.
        default: Fallback value when conversion fails.

    Returns:
        ``int(value)`` on success, *default* otherwise.
    """
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError is what ``int()`` raises for +/-infinity (NaN raises
        # ValueError); without it an infinite config value would crash
        # instead of falling back to the default.
        logger.warning(
            "safe_int: could not convert %r to int, using default %s",
            value,
            default,
        )
        return default
# ---------------------------------------------------------------------------
# HybridRecoveryConfig
# ---------------------------------------------------------------------------
# Accepted values for ``NLSQConfig.workflow``.
_VALID_WORKFLOWS: frozenset[str] = frozenset(("auto", "auto_global", "hpc"))

# Accepted values for ``NLSQConfig.goal``.
_VALID_GOALS: frozenset[str] = frozenset(
    ("fast", "robust", "quality", "memory_efficient")
)

# Accepted values for ``NLSQConfig.analysis_mode``.
_VALID_ANALYSIS_MODES: frozenset[str] = frozenset(
    ("static_ref", "static_both", "two_component")
)

# Accepted values for ``NLSQConfig.nlsq_stability``.
_VALID_NLSQ_STABILITY: frozenset[str] = frozenset(("auto", "check", "off"))
[docs]
@dataclass
class HybridRecoveryConfig:
    """Retry / fallback knobs used when an NLSQ fit has to be recovered.

    After a failed fit the optimizer retries with stronger regularisation
    and a smaller trust region. Retry number *k* (counted from 0) scales
    the baseline settings by:

    - learning rate  : ``lr_decay ** k``
    - regularisation : ``lambda_growth ** k``
    - trust radius   : ``trust_decay ** k``

    Attributes:
        max_retries: Maximum number of recovery attempts before giving up.
        lr_decay: Per-retry multiplier for the learning rate (values below
            1 shrink the effective step size).
        lambda_growth: Per-retry multiplier for the regularisation strength
            (values above 1 increase damping).
        trust_decay: Per-retry multiplier for the trust-region radius
            (values below 1 tighten the constraint).
        perturb_scale: Standard deviation of the Gaussian perturbation
            added to the starting parameters before each retry, expressed
            as a fraction of the parameter range.
        log_retries: Boolean flag, presumably controlling whether retry
            attempts are logged; nothing in this class reads it, so
            confirm against the consuming optimizer.
    """

    max_retries: int = 3
    lr_decay: float = 0.5
    lambda_growth: float = 10.0
    trust_decay: float = 0.5
    perturb_scale: float = 0.1
    log_retries: bool = True

    def get_retry_settings(self, attempt: int) -> dict[str, float]:
        """Compute the multiplicative scale factors for one retry attempt.

        Args:
            attempt: 0-based retry index. ``attempt=0`` gives the unscaled
                baseline, i.e. every factor equals 1.

        Returns:
            Dictionary with keys ``"lr_scale"``, ``"lambda_scale"`` and
            ``"trust_radius_scale"``; each value is the factor to multiply
            into the matching optimiser hyperparameter.

        Raises:
            ValueError: If *attempt* is negative.
        """
        if attempt < 0:
            raise ValueError(f"attempt must be >= 0, got {attempt}")

        # Output key paired with the per-retry base that is raised to the
        # power ``attempt``; the tuple order fixes the dict key order.
        per_retry_bases = (
            ("lr_scale", self.lr_decay),
            ("lambda_scale", self.lambda_growth),
            ("trust_radius_scale", self.trust_decay),
        )
        return {key: base**attempt for key, base in per_retry_bases}
# ---------------------------------------------------------------------------
# NLSQValidationConfig
# ---------------------------------------------------------------------------
[docs]
@dataclass
class NLSQValidationConfig:
    """Thresholds used when validating post-fit quality metrics.

    Attributes:
        chi2_warn_low: Reduced chi-squared below this value triggers a
            warning (possible over-fitting or over-estimated errors).
        chi2_warn_high: Reduced chi-squared above this value triggers a
            warning (possible under-fitting or under-estimated errors).
        chi2_fail_high: Reduced chi-squared above this value is treated as a
            hard failure.
        max_relative_uncertainty: Maximum acceptable relative uncertainty
            (sigma / ``|param|``) for any fitted parameter. Value of 1.0
            means 100 %.
        correlation_warn: Off-diagonal correlation coefficient magnitude
            above this threshold triggers a collinearity warning.
    """

    # Reduced chi-squared thresholds
    chi2_warn_low: float = 0.5
    chi2_warn_high: float = 2.0
    chi2_fail_high: float = 10.0

    # Uncertainty validation
    max_relative_uncertainty: float = 1.0  # 100 %

    # Correlation threshold for parameters
    correlation_warn: float = 0.95
# ---------------------------------------------------------------------------
# NLSQConfig
# ---------------------------------------------------------------------------
[docs]
@dataclass
class NLSQConfig:
"""Master configuration for NLSQ fitting of heterodyne XPCS data.
The heterodyne model has 14 parameters organised into two-component
(signal + background) correlation functions. This configuration covers
the full pipeline: solver hyperparameters, multi-start, streaming /
chunking, recovery on failure, and post-fit diagnostics.
Attributes:
max_iterations: Maximum number of optimiser iterations per fit.
tolerance: Convergence tolerance for the cost function.
method: Trust-region algorithm variant passed to the
``nlsq CurveFit optimizer``. Note that ``dogbox`` is coerced
to ``trf`` by the strategy layer.
multistart: Whether to run multi-start optimisation to avoid local
minima.
multistart_n: Number of random starting points when *multistart* is
enabled.
verbose: Verbosity level forwarded to the solver (0 = silent,
1 = summary, 2 = detailed).
use_jac: Whether to supply an analytic Jacobian to the solver.
x_scale: Parameter scaling strategy. ``"jac"`` uses the Jacobian
diagonal; a list of floats provides explicit per-parameter scales.
ftol: Relative tolerance on the cost function change.
xtol: Relative tolerance on the parameter step norm.
gtol: Absolute tolerance on the projected gradient norm.
loss: Robust loss function kernel.
diff_step: Finite-difference step size. ``None`` selects the
solver default.
max_nfev: Hard cap on function evaluations. ``None`` is unlimited.
chunk_size: Number of q-points per processing chunk. ``None`` means
auto-select based on available memory.
workflow: High-level workflow preset. One of ``"auto"``,
``"auto_global"``, ``"hpc"``.
goal: Optimisation goal preset controlling the balance between speed,
robustness, and solution quality. One of ``"fast"``,
``"robust"``, ``"quality"``, ``"memory_efficient"``.
enable_streaming: Process data in a streaming fashion (chunk-by-chunk)
rather than loading all q-points at once.
streaming_chunk_size: Number of q-points per streaming chunk when
*enable_streaming* is ``True``.
enable_stratified: Use stratified sampling across q-point subsets.
target_chunk_size: Target number of data points per stratified chunk.
enable_recovery: Automatically retry failed fits with more aggressive
regularisation (see *recovery_config*).
max_recovery_attempts: Maximum retries before a fit is declared failed.
recovery_config: Per-retry scaling parameters.
enable_diagnostics: Emit structured convergence / quality diagnostics
after each fit.
enable_anti_degeneracy: Apply anti-degeneracy constraints to prevent
parameter collapse (e.g. two identical relaxation modes).
x_scale_map: Per-parameter scale overrides keyed by parameter name.
Entries here are merged into (and override) the default
Jacobian-based scaling.
loss_weights: Per-data-point loss weights. ``None`` uses uniform
weighting.
loss_scale: Global scale factor applied to the loss function value
before passing to the solver.
tr_solver: Trust-region sub-problem solver override (``"exact"``,
``"lsmr"``, or ``None`` for solver default).
step_bound: Upper bound on the step norm relative to the trust radius.
``0.0`` defers to the solver default.
use_nlsq_library: Prefer the ``nlsq`` library over the scipy fallback.
n_params: Number of model parameters. Fixed at 14 for heterodyne.
analysis_mode: Which physical model variant to use. One of
``"static_ref"`` (reference beam treated as static background),
``"static_both"`` (both beams treated as static),
``"two_component"`` (full two-component model, default).
validation: Post-fit validation thresholds.
"""
# ------------------------------------------------------------------
# Existing / core solver fields
# ------------------------------------------------------------------
max_iterations: int = 1000
tolerance: float = 1e-8
method: Literal["trf", "lm", "dogbox"] = "trf"
multistart: bool = False
multistart_n: int = 10
verbose: int = 1
use_jac: bool = True
x_scale: str | list[float] = "jac"
ftol: float = 1e-8
xtol: float = 1e-8
gtol: float = 1e-8
loss: Literal["linear", "soft_l1", "huber", "cauchy", "arctan"] = "soft_l1"
trust_region_scale: float = 1.0
# Advanced solver options
diff_step: float | None = None
max_nfev: int | None = None
# Memory management
chunk_size: int | None = None # None for auto
# ------------------------------------------------------------------
# Workflow / goal presets
# ------------------------------------------------------------------
workflow: str = "auto"
goal: str = "robust"
# ------------------------------------------------------------------
# Streaming and stratified sampling
# ------------------------------------------------------------------
enable_streaming: bool = False
streaming_chunk_size: int = 50000
enable_stratified: bool = False
target_chunk_size: int = 10000
# ------------------------------------------------------------------
# Recovery
# ------------------------------------------------------------------
enable_recovery: bool = True
max_recovery_attempts: int = 3
recovery_config: HybridRecoveryConfig = field(default_factory=HybridRecoveryConfig)
# ------------------------------------------------------------------
# Diagnostics and anti-degeneracy
# ------------------------------------------------------------------
enable_diagnostics: bool = True
enable_anti_degeneracy: bool = True
# ------------------------------------------------------------------
# Loss and scaling overrides
# ------------------------------------------------------------------
x_scale_map: dict[str, float] = field(default_factory=dict)
loss_weights: list[float] | None = None
loss_scale: float = 1.0
tr_solver: str | None = None
step_bound: float = 0.0
# ------------------------------------------------------------------
# Fourier reparameterization for per-angle scaling
# ------------------------------------------------------------------
per_angle_mode: Literal[
"individual", "constant", "fourier", "auto", "independent"
] = "auto"
fourier_order: int = 2
fourier_auto_threshold: int = 6
# ------------------------------------------------------------------
# Hierarchical optimization
# ------------------------------------------------------------------
enable_hierarchical: bool = False
hierarchical_max_outer_iterations: int = 20
hierarchical_inner_tolerance: float = 1e-6
hierarchical_outer_tolerance: float = 1e-4
# Homodyne-parity hierarchical fields
hierarchical_physical_max_iterations: int = 100
hierarchical_per_angle_max_iterations: int = 50
# ------------------------------------------------------------------
# Adaptive regularization
# ------------------------------------------------------------------
regularization_mode: Literal["none", "tikhonov", "adaptive"] = "none"
group_variance_lambda: float = 0.01
regularization_target_cv: float = 0.5
# Homodyne-parity regularization fields
regularization_target_contribution: float = 0.10 # 10% of MSE contribution
regularization_max_cv: float = 0.20 # 20% max variation
regularization_auto_tune_lambda: bool = True
# ------------------------------------------------------------------
# Gradient collapse detection
# ------------------------------------------------------------------
enable_gradient_monitoring: bool = False
gradient_ratio_threshold: float = 100.0
gradient_consecutive_triggers: int = 3
# Homodyne-parity gradient collapse response field
gradient_collapse_response: str = (
"hierarchical" # "warn", "hierarchical", "reset", "abort"
)
# ------------------------------------------------------------------
# CMA-ES global search (legacy heterodyne fields)
# ------------------------------------------------------------------
enable_cmaes: bool = False
cmaes_sigma0: float = 0.3
cmaes_max_iterations: int = 1000
cmaes_population_size: int | None = None
cmaes_tolx: float = 1e-6
cmaes_tolfun: float = 1e-8
cmaes_diagonal_filtering: str = "remove"
cmaes_anti_degeneracy: bool = False
cmaes_warmstart_auto_skip: bool = True
cmaes_warmstart_skip_threshold: float = 5.0
cmaes_restart_strategy: str = "bipop"
cmaes_max_restarts: int = 9
# ------------------------------------------------------------------
# CMA-ES global search (homodyne-parity fields, NLSQ v0.6.4+)
# ------------------------------------------------------------------
cmaes_preset: str = "cmaes" # "cmaes-fast", "cmaes", "cmaes-global"
cmaes_max_generations: int | None = None # None = use preset + adaptive scaling
cmaes_popsize: int | None = None # Population size (None = auto)
cmaes_sigma: float = 0.5 # Initial step size (fraction of search range)
cmaes_sigma_warmstart: float = 0.05 # Reduced sigma for warm-start mode
cmaes_tol_fun: float = 1e-8 # Function value tolerance for convergence
cmaes_tol_x: float = 1e-8 # Parameter tolerance for convergence
cmaes_population_batch_size: int | None = None # Memory batching (None = auto)
cmaes_data_chunk_size: int | None = None # Data streaming (None = auto)
cmaes_refine_with_nlsq: bool = True # Refine CMA-ES solution with NLSQ TRF
cmaes_auto_select: bool = (
True # Auto-select CMA-ES vs multi-start based on scale ratio
)
cmaes_scale_threshold: float = 1000.0 # Scale ratio threshold for auto-selection
cmaes_memory_limit_gb: float = 8.0 # Memory limit for auto-configuration
# Post-CMA-ES NLSQ TRF Refinement
cmaes_refinement_workflow: str = "auto" # "auto", "standard", "streaming"
cmaes_refinement_ftol: float = 1e-10
cmaes_refinement_xtol: float = 1e-10
cmaes_refinement_gtol: float = 1e-10
cmaes_refinement_max_nfev: int = 500
cmaes_refinement_loss: str = "linear" # "linear", "soft_l1", "huber"
# CMA-ES Parameter Normalization
cmaes_normalize: bool = True # Enable bounds-based normalization
cmaes_normalization_epsilon: float = 1e-12 # Prevent division by zero
# ------------------------------------------------------------------
# Fit Quality Validation (homodyne-parity fields, v2.16.0)
# ------------------------------------------------------------------
enable_quality_validation: bool = True # Enable post-fit quality checks
quality_reduced_chi_squared_threshold: float = 10.0 # Warn if χ²_red > threshold
quality_warn_on_max_restarts: bool = True # Warn if CMA-ES didn't converge
quality_warn_on_bounds_hit: bool = True # Warn if physical params at bounds
quality_warn_on_convergence_failure: bool = True # Warn if optimization failed
quality_bounds_tolerance: float = 1e-9 # Tolerance for "at bounds" detection
# ------------------------------------------------------------------
# Progress and logging settings
# ------------------------------------------------------------------
enable_progress_bar: bool = True # Show tqdm progress bar during fitting
log_iteration_interval: int = 10 # Log every N iterations (for verbose >= 2)
# ------------------------------------------------------------------
# Hybrid streaming optimizer (legacy heterodyne fields)
# ------------------------------------------------------------------
hybrid_enable: bool = False
hybrid_warmup_fraction: float = 0.1
hybrid_normalization: bool = True
hybrid_method: Literal["lbfgs", "gauss_newton"] = "gauss_newton"
hybrid_lbfgs_memory: int = 10
hybrid_convergence_window: int = 5
hybrid_convergence_threshold: float = 1e-6
hybrid_max_phases: int = 4
# ------------------------------------------------------------------
# Hybrid streaming optimizer (homodyne-parity fields, v2.6.0+)
# ------------------------------------------------------------------
enable_hybrid_streaming: bool = True
hybrid_normalize: bool = True
hybrid_normalization_strategy: str = "auto" # 'auto', 'bounds', 'p0', 'none'
hybrid_warmup_iterations: int = 200
hybrid_max_warmup_iterations: int = 500
hybrid_warmup_learning_rate: float = 0.001
hybrid_gauss_newton_max_iterations: int = 100
hybrid_gauss_newton_tol: float = 1e-8
hybrid_chunk_size: int = 10000
hybrid_trust_region_initial: float = 1.0
hybrid_regularization_factor: float = 1e-10
hybrid_enable_checkpoints: bool = True
hybrid_checkpoint_frequency: int = 100
hybrid_validate_numerics: bool = True
# 4-Layer Defense Strategy for L-BFGS Warmup (v2.8.0 / NLSQ 0.3.6)
# Layer 1: Warm Start Detection
hybrid_enable_warm_start_detection: bool = True
hybrid_warm_start_threshold: float = 0.01 # Skip if loss/variance < this
# Layer 2: Adaptive Learning Rate
hybrid_enable_adaptive_warmup_lr: bool = True
hybrid_warmup_lr_refinement: float = 1e-6 # LR for good starts
hybrid_warmup_lr_careful: float = 1e-5 # LR for moderate starts
# Layer 3: Cost-Increase Guard
hybrid_enable_cost_guard: bool = True
hybrid_cost_increase_tolerance: float = 0.05 # Abort if loss increases >5%
# Layer 4: Step Clipping
hybrid_enable_step_clipping: bool = True
hybrid_max_warmup_step_size: float = 0.1 # Max step in normalized units
# ------------------------------------------------------------------
# Multi-start extensions (legacy heterodyne fields)
# ------------------------------------------------------------------
sampling_strategy: Literal["lhs", "sobol", "random"] = "lhs"
screen_keep_fraction: float = 0.5
refine_top_k: int = 3
# ------------------------------------------------------------------
# Multi-start optimization settings (homodyne-parity fields, v2.6.0)
# ------------------------------------------------------------------
enable_multi_start: bool = False # Default OFF - user opt-in
multi_start_n_starts: int = 10
multi_start_seed: int = 42
multi_start_sampling_strategy: str = (
"latin_hypercube" # 'latin_hypercube' or 'random'
)
multi_start_n_workers: int = 0 # 0 = auto (min of n_starts, cpu_count)
multi_start_use_screening: bool = True
multi_start_screen_keep_fraction: float = 0.5
multi_start_refine_top_k: int = 3
multi_start_refinement_ftol: float = 1e-12
multi_start_degeneracy_threshold: float = 0.1
# ------------------------------------------------------------------
# Scaling threshold
# ------------------------------------------------------------------
constant_scaling_threshold: int = 3
# ------------------------------------------------------------------
# Backend and model identity
# ------------------------------------------------------------------
use_nlsq_library: bool = True
n_params: int = 14 # heterodyne: 14 parameters
analysis_mode: str = "two_component"
# ------------------------------------------------------------------
# NLSQ package integration (mirrors homodyne wrapper.py)
# ------------------------------------------------------------------
nlsq_stability: str = "auto" # 'auto', 'check', or 'off'
nlsq_rescale_data: bool = False # xdata is indices, not physical
nlsq_x_scale: str | np.ndarray = "jac" # trust-region scaling
nlsq_memory_fraction: float = 0.75 # fraction of RAM for NLSQ
nlsq_memory_fallback_gb: float = 16.0 # fallback if detection fails
# ------------------------------------------------------------------
# Post-fit validation
# ------------------------------------------------------------------
validation: NLSQValidationConfig = field(default_factory=NLSQValidationConfig)
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
[docs]
def __post_init__(self) -> None:
    """Normalise deprecated values, then reject out-of-range settings.

    The deprecated ``per_angle_mode="independent"`` is rewritten to
    ``"individual"`` with a ``DeprecationWarning``. The numeric and
    enumerated fields are then checked one at a time, in a fixed order, and
    the first violation aborts construction.

    Raises:
        ValueError: On the first field that violates its constraint.
    """

    def _reject(violated: bool, message: str) -> None:
        # Abort construction as soon as one constraint is violated.
        if violated:
            raise ValueError(message)

    # Normalise deprecated mode names to their canonical homodyne form.
    # The warning is emitted from this frame (not a helper) so that
    # ``stacklevel=3`` keeps pointing at the same caller frame.
    if self.per_angle_mode == "independent":
        deprecation_text = (
            "per_angle_mode='independent' is deprecated; use 'individual' "
            "(matches homodyne's canonical name per "
            "https://homodyne.readthedocs.io/en/latest/theory/anti_degeneracy.html). "
            "'independent' will be removed in heterodyne v1.0."
        )
        warnings.warn(deprecation_text, DeprecationWarning, stacklevel=3)
        # The Literal type accepts both names during the deprecation window.
        self.per_angle_mode = "individual"

    # Core solver, chunking and recovery limits.
    _reject(self.max_iterations < 1, "max_iterations must be >= 1")
    _reject(self.tolerance <= 0, "tolerance must be positive")
    _reject(self.multistart_n < 1, "multistart_n must be >= 1")
    _reject(self.streaming_chunk_size < 1, "streaming_chunk_size must be >= 1")
    _reject(self.target_chunk_size < 1, "target_chunk_size must be >= 1")
    _reject(self.max_recovery_attempts < 0, "max_recovery_attempts must be >= 0")
    _reject(self.loss_scale <= 0, "loss_scale must be positive")

    # Hierarchical optimisation and gradient monitoring.
    _reject(
        self.hierarchical_max_outer_iterations < 1,
        "hierarchical_max_outer_iterations must be >= 1",
    )
    _reject(
        self.gradient_consecutive_triggers < 1,
        "gradient_consecutive_triggers must be >= 1",
    )

    # Legacy CMA-ES settings.
    _reject(self.cmaes_sigma0 <= 0, "cmaes_sigma0 must be > 0")
    _reject(
        self.cmaes_diagonal_filtering not in ("remove", "none"),
        f"cmaes_diagonal_filtering must be 'remove' or 'none', "
        f"got {self.cmaes_diagonal_filtering!r}",
    )
    _reject(
        self.cmaes_warmstart_skip_threshold <= 0,
        "cmaes_warmstart_skip_threshold must be > 0",
    )
    _reject(
        self.cmaes_restart_strategy not in ("bipop", "none"),
        f"cmaes_restart_strategy must be 'bipop' or 'none', "
        f"got {self.cmaes_restart_strategy!r}",
    )
    _reject(self.cmaes_max_restarts < 0, "cmaes_max_restarts must be >= 0")

    # Legacy hybrid / multi-start fractions and counts.
    _reject(
        not (0 < self.hybrid_warmup_fraction < 1),
        "hybrid_warmup_fraction must be in (0, 1)",
    )
    _reject(
        not (0 < self.screen_keep_fraction <= 1),
        "screen_keep_fraction must be in (0, 1]",
    )
    _reject(self.refine_top_k < 1, "refine_top_k must be >= 1")
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
[docs]
def validate(self) -> list[str]:
    """Collect every configuration inconsistency as a readable message.

    Unlike construction-time checks, nothing is raised here: all
    violations are gathered and returned together. Two advisory log
    messages (a tight ``gtol`` combined with a non-linear loss, and an
    unset ``max_nfev``) are emitted through the module logger and do not
    count as errors.

    Returns:
        List of human-readable error strings, one per violation found.
        An empty list means the configuration is consistent; callers
        should treat a non-empty list as a hard error before launching
        a fit.
    """
    problems: list[str] = []

    def _check_choice(name: str, allowed: Any, shown: Any) -> None:
        # Membership test against ``allowed``; ``shown`` is the rendering
        # of the accepted values that goes into the message.
        current = getattr(self, name)
        if current not in allowed:
            problems.append(
                f"{name}={current!r} is not valid; must be one of {shown}"
            )

    def _flag(name: str, violated: bool, requirement: str) -> None:
        # Record a range violation for the numeric field ``name``.
        if violated:
            problems.append(f"{name}={getattr(self, name)} {requirement}")

    per_angle_choices = (
        "individual",
        "constant",
        "fourier",
        "auto",
        "independent",
    )
    regularization_choices = ("none", "tikhonov", "adaptive")
    hybrid_method_choices = ("lbfgs", "gauss_newton")
    sampling_choices = ("lhs", "sobol", "random")

    # The order of the checks below fixes the order of the messages.
    _check_choice("workflow", _VALID_WORKFLOWS, sorted(_VALID_WORKFLOWS))
    _check_choice("goal", _VALID_GOALS, sorted(_VALID_GOALS))
    _flag("tolerance", self.tolerance <= 0, "must be > 0")
    _flag("streaming_chunk_size", self.streaming_chunk_size <= 0, "must be > 0")
    _check_choice(
        "analysis_mode", _VALID_ANALYSIS_MODES, sorted(_VALID_ANALYSIS_MODES)
    )
    _check_choice("per_angle_mode", per_angle_choices, per_angle_choices)
    _flag("fourier_order", self.fourier_order < 1, "must be >= 1")
    _flag(
        "fourier_auto_threshold",
        self.fourier_auto_threshold < 1,
        "must be >= 1",
    )
    _check_choice(
        "regularization_mode", regularization_choices, regularization_choices
    )
    _check_choice("hybrid_method", hybrid_method_choices, hybrid_method_choices)
    _check_choice("sampling_strategy", sampling_choices, sampling_choices)
    _check_choice(
        "nlsq_stability", _VALID_NLSQ_STABILITY, sorted(_VALID_NLSQ_STABILITY)
    )
    _flag(
        "nlsq_memory_fraction",
        not (0 < self.nlsq_memory_fraction <= 1),
        "must be in (0, 1]",
    )
    _flag(
        "nlsq_memory_fallback_gb",
        self.nlsq_memory_fallback_gb <= 0,
        "must be > 0",
    )

    # Advisory warnings — not errors, but worth surfacing at validate() time
    tight_gtol = self.gtol < 1e-7
    if tight_gtol and self.loss != "linear":
        logger.warning(
            "NLSQConfig: gtol=%.2e is very tight for loss=%r. "
            "Robust loss landscapes are harder — consider gtol >= 1e-6 "
            "to avoid premature max_nfev exhaustion.",
            self.gtol,
            self.loss,
        )
    if self.max_nfev is None:
        logger.debug(
            "NLSQConfig: max_nfev=None — nlsq defaults to 100×n_params "
            "(e.g. 1400 for 14 params). Set explicitly to override.",
        )

    return problems
# ------------------------------------------------------------------
# Serialisation
# ------------------------------------------------------------------
[docs]
@classmethod
def from_dict(cls, config: dict[str, Any]) -> NLSQConfig:
"""Construct an ``NLSQConfig`` from a plain dictionary.
Nested sub-dictionaries under ``"recovery"`` and ``"validation"``
are automatically parsed into their respective dataclasses.
Unrecognised top-level keys are logged as warnings and ignored.
Args:
config: Flat or nested configuration dictionary, e.g. loaded from
a YAML file.
Returns:
Fully populated ``NLSQConfig`` instance.
"""
known_scalar_fields: dict[str, str] = {
# Core solver
"max_iterations": "int",
"tolerance": "float",
"method": "str",
"multistart": "bool",
"multistart_n": "int",
"verbose": "int",
"use_jac": "bool",
"x_scale": "passthrough", # str or list — handled separately
"ftol": "float",
"xtol": "float",
"gtol": "float",
"loss": "str",
"diff_step": "float_or_none",
"max_nfev": "int_or_none",
"chunk_size": "int_or_none",
# Workflow / goal
"workflow": "str",
"goal": "str",
# Streaming / stratified
"enable_streaming": "bool",
"streaming_chunk_size": "int",
"enable_stratified": "bool",
"target_chunk_size": "int",
# Recovery
"enable_recovery": "bool",
"max_recovery_attempts": "int",
# Diagnostics
"enable_diagnostics": "bool",
"enable_anti_degeneracy": "bool",
# Loss / scaling
"loss_weights": "passthrough", # list[float] | None
"loss_scale": "float",
"tr_solver": "str_or_none",
"step_bound": "float",
# Fourier reparameterization
"per_angle_mode": "str",
"fourier_order": "int",
"fourier_auto_threshold": "int",
# Hierarchical optimization
"enable_hierarchical": "bool",
"hierarchical_max_outer_iterations": "int",
"hierarchical_inner_tolerance": "float",
"hierarchical_outer_tolerance": "float",
# Adaptive regularization
"regularization_mode": "str",
"group_variance_lambda": "float",
"regularization_target_cv": "float",
# Gradient collapse detection
"enable_gradient_monitoring": "bool",
"gradient_ratio_threshold": "float",
"gradient_consecutive_triggers": "int",
# Hierarchical optimization (homodyne-parity)
"hierarchical_physical_max_iterations": "int",
"hierarchical_per_angle_max_iterations": "int",
# Adaptive regularization (homodyne-parity)
"regularization_target_contribution": "float",
"regularization_max_cv": "float",
"regularization_auto_tune_lambda": "bool",
# Gradient collapse detection (homodyne-parity)
"gradient_collapse_response": "str",
# CMA-ES global search (legacy heterodyne fields)
"enable_cmaes": "bool",
"cmaes_sigma0": "float",
"cmaes_max_iterations": "int",
"cmaes_population_size": "int_or_none",
"cmaes_tolx": "float",
"cmaes_tolfun": "float",
"cmaes_diagonal_filtering": "str",
"cmaes_anti_degeneracy": "bool",
"cmaes_warmstart_auto_skip": "bool",
"cmaes_warmstart_skip_threshold": "float",
"cmaes_restart_strategy": "str",
"cmaes_max_restarts": "int",
# CMA-ES global search (homodyne-parity fields)
"cmaes_preset": "str",
"cmaes_max_generations": "int_or_none",
"cmaes_popsize": "int_or_none",
"cmaes_sigma": "float",
"cmaes_sigma_warmstart": "float",
"cmaes_tol_fun": "float",
"cmaes_tol_x": "float",
"cmaes_population_batch_size": "int_or_none",
"cmaes_data_chunk_size": "int_or_none",
"cmaes_refine_with_nlsq": "bool",
"cmaes_auto_select": "bool",
"cmaes_scale_threshold": "float",
"cmaes_memory_limit_gb": "float",
"cmaes_refinement_workflow": "str",
"cmaes_refinement_ftol": "float",
"cmaes_refinement_xtol": "float",
"cmaes_refinement_gtol": "float",
"cmaes_refinement_max_nfev": "int",
"cmaes_refinement_loss": "str",
"cmaes_normalize": "bool",
"cmaes_normalization_epsilon": "float",
# Progress and logging (homodyne-parity)
"enable_progress_bar": "bool",
"log_iteration_interval": "int",
# Hybrid streaming optimizer (legacy heterodyne fields)
"hybrid_enable": "bool",
"hybrid_warmup_fraction": "float",
"hybrid_normalization": "bool",
"hybrid_method": "str",
"hybrid_lbfgs_memory": "int",
"hybrid_convergence_window": "int",
"hybrid_convergence_threshold": "float",
"hybrid_max_phases": "int",
# Hybrid streaming optimizer (homodyne-parity fields)
"enable_hybrid_streaming": "bool",
"hybrid_normalize": "bool",
"hybrid_normalization_strategy": "str",
"hybrid_warmup_iterations": "int",
"hybrid_max_warmup_iterations": "int",
"hybrid_warmup_learning_rate": "float",
"hybrid_gauss_newton_max_iterations": "int",
"hybrid_gauss_newton_tol": "float",
"hybrid_chunk_size": "int",
"hybrid_trust_region_initial": "float",
"hybrid_regularization_factor": "float",
"hybrid_enable_checkpoints": "bool",
"hybrid_checkpoint_frequency": "int",
"hybrid_validate_numerics": "bool",
"hybrid_enable_warm_start_detection": "bool",
"hybrid_warm_start_threshold": "float",
"hybrid_enable_adaptive_warmup_lr": "bool",
"hybrid_warmup_lr_refinement": "float",
"hybrid_warmup_lr_careful": "float",
"hybrid_enable_cost_guard": "bool",
"hybrid_cost_increase_tolerance": "float",
"hybrid_enable_step_clipping": "bool",
"hybrid_max_warmup_step_size": "float",
# Multi-start extensions (legacy heterodyne fields)
"sampling_strategy": "str",
"screen_keep_fraction": "float",
"refine_top_k": "int",
# Multi-start optimization (homodyne-parity fields)
"enable_multi_start": "bool",
"multi_start_n_starts": "int",
"multi_start_seed": "int",
"multi_start_sampling_strategy": "str",
"multi_start_n_workers": "int",
"multi_start_use_screening": "bool",
"multi_start_screen_keep_fraction": "float",
"multi_start_refine_top_k": "int",
"multi_start_refinement_ftol": "float",
"multi_start_degeneracy_threshold": "float",
# Fit quality validation (homodyne-parity fields)
"enable_quality_validation": "bool",
"quality_reduced_chi_squared_threshold": "float",
"quality_warn_on_max_restarts": "bool",
"quality_warn_on_bounds_hit": "bool",
"quality_warn_on_convergence_failure": "bool",
"quality_bounds_tolerance": "float",
# Scaling threshold
"constant_scaling_threshold": "int",
# Backend / model
"use_nlsq_library": "bool",
"n_params": "int",
"analysis_mode": "str",
# NLSQ package integration
"nlsq_stability": "str",
"nlsq_rescale_data": "bool",
"nlsq_x_scale": "passthrough", # str or np.ndarray
"nlsq_memory_fraction": "float",
"nlsq_memory_fallback_gb": "float",
# Loss function scale (homodyne-parity)
"trust_region_scale": "float",
}
normalized_config = dict(config)
def _set_from_nested(field_name: str, value: Any) -> None:
if value is not _SENTINEL and field_name not in normalized_config:
normalized_config[field_name] = value
raw_anti_degeneracy = config.get("anti_degeneracy")
if isinstance(raw_anti_degeneracy, dict):
_set_from_nested(
"per_angle_mode",
raw_anti_degeneracy.get("per_angle_mode", _SENTINEL),
)
_set_from_nested(
"fourier_order",
raw_anti_degeneracy.get("fourier_order", _SENTINEL),
)
_set_from_nested(
"fourier_auto_threshold",
raw_anti_degeneracy.get("fourier_auto_threshold", _SENTINEL),
)
_set_from_nested(
"constant_scaling_threshold",
raw_anti_degeneracy.get("constant_scaling_threshold", _SENTINEL),
)
hierarchical = raw_anti_degeneracy.get("hierarchical")
if isinstance(hierarchical, dict):
_set_from_nested(
"enable_hierarchical", hierarchical.get("enable", _SENTINEL)
)
_set_from_nested(
"hierarchical_max_outer_iterations",
hierarchical.get("max_outer_iterations", _SENTINEL),
)
_set_from_nested(
"hierarchical_inner_tolerance",
hierarchical.get("inner_tolerance", _SENTINEL),
)
_set_from_nested(
"hierarchical_outer_tolerance",
hierarchical.get("outer_tolerance", _SENTINEL),
)
# Homodyne-parity hierarchical fields
_set_from_nested(
"hierarchical_physical_max_iterations",
hierarchical.get("physical_max_iterations", _SENTINEL),
)
_set_from_nested(
"hierarchical_per_angle_max_iterations",
hierarchical.get("per_angle_max_iterations", _SENTINEL),
)
elif hierarchical is not None:
logger.warning(
"NLSQConfig.from_dict: anti_degeneracy.hierarchical must be a "
"dict, got %r — ignoring",
type(hierarchical).__name__,
)
regularization = raw_anti_degeneracy.get("regularization")
if isinstance(regularization, dict):
_set_from_nested(
"regularization_mode", regularization.get("mode", _SENTINEL)
)
_set_from_nested(
"group_variance_lambda", regularization.get("lambda", _SENTINEL)
)
_set_from_nested(
"regularization_target_cv",
regularization.get("target_cv", _SENTINEL),
)
# Homodyne-parity regularization fields
_set_from_nested(
"regularization_target_contribution",
regularization.get("target_contribution", _SENTINEL),
)
_set_from_nested(
"regularization_max_cv",
regularization.get("max_cv", _SENTINEL),
)
_set_from_nested(
"regularization_auto_tune_lambda",
regularization.get("auto_tune_lambda", _SENTINEL),
)
elif regularization is not None:
logger.warning(
"NLSQConfig.from_dict: anti_degeneracy.regularization must be a "
"dict, got %r — ignoring",
type(regularization).__name__,
)
gradient_monitoring = raw_anti_degeneracy.get("gradient_monitoring")
if isinstance(gradient_monitoring, dict):
_set_from_nested(
"enable_gradient_monitoring",
gradient_monitoring.get("enable", _SENTINEL),
)
_set_from_nested(
"gradient_ratio_threshold",
gradient_monitoring.get("ratio_threshold", _SENTINEL),
)
_set_from_nested(
"gradient_consecutive_triggers",
gradient_monitoring.get("consecutive_triggers", _SENTINEL),
)
# Homodyne-parity gradient collapse response field
_set_from_nested(
"gradient_collapse_response",
gradient_monitoring.get("response", _SENTINEL),
)
elif gradient_monitoring is not None:
logger.warning(
"NLSQConfig.from_dict: anti_degeneracy.gradient_monitoring must "
"be a dict, got %r — ignoring",
type(gradient_monitoring).__name__,
)
elif raw_anti_degeneracy is not None:
logger.warning(
"NLSQConfig.from_dict: 'anti_degeneracy' must be a dict, got %r — "
"ignoring",
type(raw_anti_degeneracy).__name__,
)
raw_cmaes = config.get("cmaes")
if isinstance(raw_cmaes, dict):
_set_from_nested("enable_cmaes", raw_cmaes.get("enable", _SENTINEL))
_set_from_nested(
"cmaes_sigma0",
raw_cmaes.get("sigma", raw_cmaes.get("sigma0", _SENTINEL)),
)
_set_from_nested(
"cmaes_max_iterations",
raw_cmaes.get(
"max_generations",
raw_cmaes.get("max_iterations", _SENTINEL),
),
)
_set_from_nested(
"cmaes_population_size",
raw_cmaes.get("popsize", raw_cmaes.get("population_size", _SENTINEL)),
)
_set_from_nested(
"cmaes_tolx", raw_cmaes.get("tol_x", raw_cmaes.get("tolx", _SENTINEL))
)
_set_from_nested(
"cmaes_tolfun",
raw_cmaes.get("tol_fun", raw_cmaes.get("tolfun", _SENTINEL)),
)
_set_from_nested(
"cmaes_diagonal_filtering",
raw_cmaes.get("diagonal_filtering", _SENTINEL),
)
_set_from_nested(
"cmaes_anti_degeneracy",
raw_cmaes.get("anti_degeneracy", _SENTINEL),
)
_set_from_nested(
"cmaes_warmstart_auto_skip",
raw_cmaes.get("warmstart_auto_skip", _SENTINEL),
)
_set_from_nested(
"cmaes_warmstart_skip_threshold",
raw_cmaes.get("warmstart_skip_threshold", _SENTINEL),
)
# Homodyne-parity CMA-ES fields
_set_from_nested("cmaes_preset", raw_cmaes.get("preset", _SENTINEL))
_set_from_nested(
"cmaes_max_generations",
raw_cmaes.get("max_generations", _SENTINEL),
)
_set_from_nested("cmaes_popsize", raw_cmaes.get("popsize", _SENTINEL))
_set_from_nested("cmaes_sigma", raw_cmaes.get("sigma", _SENTINEL))
_set_from_nested(
"cmaes_sigma_warmstart",
raw_cmaes.get("sigma_warmstart", _SENTINEL),
)
_set_from_nested("cmaes_tol_fun", raw_cmaes.get("tol_fun", _SENTINEL))
_set_from_nested("cmaes_tol_x", raw_cmaes.get("tol_x", _SENTINEL))
_set_from_nested(
"cmaes_population_batch_size",
raw_cmaes.get("population_batch_size", _SENTINEL),
)
_set_from_nested(
"cmaes_data_chunk_size",
raw_cmaes.get("data_chunk_size", _SENTINEL),
)
_set_from_nested(
"cmaes_refine_with_nlsq",
raw_cmaes.get("refine_with_nlsq", _SENTINEL),
)
_set_from_nested(
"cmaes_auto_select", raw_cmaes.get("auto_select", _SENTINEL)
)
_set_from_nested(
"cmaes_scale_threshold",
raw_cmaes.get("scale_threshold", _SENTINEL),
)
_set_from_nested(
"cmaes_memory_limit_gb",
raw_cmaes.get("memory_limit_gb", _SENTINEL),
)
_set_from_nested(
"cmaes_refinement_workflow",
raw_cmaes.get("refinement_workflow", _SENTINEL),
)
_set_from_nested(
"cmaes_refinement_ftol",
raw_cmaes.get("refinement_ftol", _SENTINEL),
)
_set_from_nested(
"cmaes_refinement_xtol",
raw_cmaes.get("refinement_xtol", _SENTINEL),
)
_set_from_nested(
"cmaes_refinement_gtol",
raw_cmaes.get("refinement_gtol", _SENTINEL),
)
_set_from_nested(
"cmaes_refinement_max_nfev",
raw_cmaes.get("refinement_max_nfev", _SENTINEL),
)
_set_from_nested(
"cmaes_refinement_loss",
raw_cmaes.get("refinement_loss", _SENTINEL),
)
_set_from_nested("cmaes_normalize", raw_cmaes.get("normalize", _SENTINEL))
_set_from_nested(
"cmaes_normalization_epsilon",
raw_cmaes.get("normalization_epsilon", _SENTINEL),
)
elif raw_cmaes is not None:
logger.warning(
"NLSQConfig.from_dict: 'cmaes' must be a dict, got %r — ignoring",
type(raw_cmaes).__name__,
)
# Homodyne-parity: progress section
raw_progress = config.get("progress")
if isinstance(raw_progress, dict):
_set_from_nested(
"enable_progress_bar", raw_progress.get("enable", _SENTINEL)
)
_set_from_nested("verbose", raw_progress.get("verbose", _SENTINEL))
_set_from_nested(
"log_iteration_interval", raw_progress.get("log_interval", _SENTINEL)
)
elif raw_progress is not None:
logger.warning(
"NLSQConfig.from_dict: 'progress' must be a dict, got %r — ignoring",
type(raw_progress).__name__,
)
# Homodyne-parity: hybrid_streaming section
raw_hybrid_streaming = config.get("hybrid_streaming")
if isinstance(raw_hybrid_streaming, dict):
_set_from_nested(
"enable_hybrid_streaming",
raw_hybrid_streaming.get("enable", _SENTINEL),
)
_set_from_nested(
"hybrid_normalize", raw_hybrid_streaming.get("normalize", _SENTINEL)
)
_set_from_nested(
"hybrid_normalization_strategy",
raw_hybrid_streaming.get("normalization_strategy", _SENTINEL),
)
_set_from_nested(
"hybrid_warmup_iterations",
raw_hybrid_streaming.get("warmup_iterations", _SENTINEL),
)
_set_from_nested(
"hybrid_max_warmup_iterations",
raw_hybrid_streaming.get("max_warmup_iterations", _SENTINEL),
)
_set_from_nested(
"hybrid_warmup_learning_rate",
raw_hybrid_streaming.get("warmup_learning_rate", _SENTINEL),
)
_set_from_nested(
"hybrid_gauss_newton_max_iterations",
raw_hybrid_streaming.get("gauss_newton_max_iterations", _SENTINEL),
)
_set_from_nested(
"hybrid_gauss_newton_tol",
raw_hybrid_streaming.get("gauss_newton_tol", _SENTINEL),
)
_set_from_nested(
"hybrid_chunk_size", raw_hybrid_streaming.get("chunk_size", _SENTINEL)
)
_set_from_nested(
"hybrid_trust_region_initial",
raw_hybrid_streaming.get("trust_region_initial", _SENTINEL),
)
_set_from_nested(
"hybrid_regularization_factor",
raw_hybrid_streaming.get("regularization_factor", _SENTINEL),
)
_set_from_nested(
"hybrid_enable_checkpoints",
raw_hybrid_streaming.get("enable_checkpoints", _SENTINEL),
)
_set_from_nested(
"hybrid_checkpoint_frequency",
raw_hybrid_streaming.get("checkpoint_frequency", _SENTINEL),
)
_set_from_nested(
"hybrid_validate_numerics",
raw_hybrid_streaming.get("validate_numerics", _SENTINEL),
)
_set_from_nested(
"hybrid_enable_warm_start_detection",
raw_hybrid_streaming.get("enable_warm_start_detection", _SENTINEL),
)
_set_from_nested(
"hybrid_warm_start_threshold",
raw_hybrid_streaming.get("warm_start_threshold", _SENTINEL),
)
_set_from_nested(
"hybrid_enable_adaptive_warmup_lr",
raw_hybrid_streaming.get("enable_adaptive_warmup_lr", _SENTINEL),
)
_set_from_nested(
"hybrid_warmup_lr_refinement",
raw_hybrid_streaming.get("warmup_lr_refinement", _SENTINEL),
)
_set_from_nested(
"hybrid_warmup_lr_careful",
raw_hybrid_streaming.get("warmup_lr_careful", _SENTINEL),
)
_set_from_nested(
"hybrid_enable_cost_guard",
raw_hybrid_streaming.get("enable_cost_guard", _SENTINEL),
)
_set_from_nested(
"hybrid_cost_increase_tolerance",
raw_hybrid_streaming.get("cost_increase_tolerance", _SENTINEL),
)
_set_from_nested(
"hybrid_enable_step_clipping",
raw_hybrid_streaming.get("enable_step_clipping", _SENTINEL),
)
_set_from_nested(
"hybrid_max_warmup_step_size",
raw_hybrid_streaming.get("max_warmup_step_size", _SENTINEL),
)
elif raw_hybrid_streaming is not None:
logger.warning(
"NLSQConfig.from_dict: 'hybrid_streaming' must be a dict, got %r — ignoring",
type(raw_hybrid_streaming).__name__,
)
# Homodyne-parity: multi_start section
raw_multi_start = config.get("multi_start")
if isinstance(raw_multi_start, dict):
_set_from_nested(
"enable_multi_start", raw_multi_start.get("enable", _SENTINEL)
)
_set_from_nested(
"multi_start_n_starts", raw_multi_start.get("n_starts", _SENTINEL)
)
_set_from_nested("multi_start_seed", raw_multi_start.get("seed", _SENTINEL))
_set_from_nested(
"multi_start_sampling_strategy",
raw_multi_start.get("sampling_strategy", _SENTINEL),
)
_set_from_nested(
"multi_start_n_workers", raw_multi_start.get("n_workers", _SENTINEL)
)
_set_from_nested(
"multi_start_use_screening",
raw_multi_start.get("use_screening", _SENTINEL),
)
_set_from_nested(
"multi_start_screen_keep_fraction",
raw_multi_start.get("screen_keep_fraction", _SENTINEL),
)
_set_from_nested(
"multi_start_refine_top_k",
raw_multi_start.get("refine_top_k", _SENTINEL),
)
_set_from_nested(
"multi_start_refinement_ftol",
raw_multi_start.get("refinement_ftol", _SENTINEL),
)
_set_from_nested(
"multi_start_degeneracy_threshold",
raw_multi_start.get("degeneracy_threshold", _SENTINEL),
)
elif raw_multi_start is not None:
logger.warning(
"NLSQConfig.from_dict: 'multi_start' must be a dict, got %r — ignoring",
type(raw_multi_start).__name__,
)
# Homodyne-parity: quality_validation section
raw_quality = config.get("quality_validation")
if isinstance(raw_quality, dict):
_set_from_nested(
"enable_quality_validation", raw_quality.get("enable", _SENTINEL)
)
_set_from_nested(
"quality_reduced_chi_squared_threshold",
raw_quality.get("reduced_chi_squared_threshold", _SENTINEL),
)
_set_from_nested(
"quality_warn_on_max_restarts",
raw_quality.get("warn_on_max_restarts", _SENTINEL),
)
_set_from_nested(
"quality_warn_on_bounds_hit",
raw_quality.get("warn_on_bounds_hit", _SENTINEL),
)
_set_from_nested(
"quality_warn_on_convergence_failure",
raw_quality.get("warn_on_convergence_failure", _SENTINEL),
)
_set_from_nested(
"quality_bounds_tolerance",
raw_quality.get("bounds_tolerance", _SENTINEL),
)
elif raw_quality is not None:
logger.warning(
"NLSQConfig.from_dict: 'quality_validation' must be a dict, got %r — ignoring",
type(raw_quality).__name__,
)
nested_keys = {
"recovery",
"validation",
"x_scale_map",
"anti_degeneracy",
"cmaes",
"progress",
"hybrid_streaming",
"multi_start",
"quality_validation",
}
# Warn on unrecognised keys
all_known = set(known_scalar_fields) | nested_keys
for key in normalized_config:
if key not in all_known:
logger.warning(
"NLSQConfig.from_dict: unrecognised key %r — ignoring", key
)
kwargs: dict[str, Any] = {}
# --- Parse scalar fields -----------------------------------------
for field_name, kind in known_scalar_fields.items():
raw = normalized_config.get(field_name, _SENTINEL)
if raw is _SENTINEL:
continue # use dataclass default
if kind == "float":
kwargs[field_name] = safe_float(raw, 0.0)
elif kind == "int":
kwargs[field_name] = safe_int(raw, 0)
elif kind == "bool":
kwargs[field_name] = bool(raw)
elif kind == "str":
kwargs[field_name] = str(raw)
elif kind == "float_or_none":
kwargs[field_name] = None if raw is None else safe_float(raw, 0.0)
elif kind == "int_or_none":
kwargs[field_name] = None if raw is None else safe_int(raw, 0)
elif kind == "str_or_none":
kwargs[field_name] = None if raw is None else str(raw)
elif kind == "passthrough":
kwargs[field_name] = raw
# no else branch needed — exhaustive set above
# --- Parse x_scale_map -------------------------------------------
raw_scale_map = normalized_config.get("x_scale_map")
if isinstance(raw_scale_map, dict):
kwargs["x_scale_map"] = {
str(k): safe_float(v, 1.0) for k, v in raw_scale_map.items()
}
elif raw_scale_map is not None:
logger.warning(
"NLSQConfig.from_dict: x_scale_map must be a dict, got %r — ignoring",
type(raw_scale_map).__name__,
)
# --- Parse nested recovery sub-dict ------------------------------
raw_recovery = normalized_config.get("recovery")
if isinstance(raw_recovery, dict):
recovery = HybridRecoveryConfig(
max_retries=safe_int(
raw_recovery.get("max_retries"), HybridRecoveryConfig.max_retries
),
lr_decay=safe_float(
raw_recovery.get("lr_decay"), HybridRecoveryConfig.lr_decay
),
lambda_growth=safe_float(
raw_recovery.get("lambda_growth"),
HybridRecoveryConfig.lambda_growth,
),
trust_decay=safe_float(
raw_recovery.get("trust_decay"), HybridRecoveryConfig.trust_decay
),
perturb_scale=safe_float(
raw_recovery.get("perturb_scale"),
HybridRecoveryConfig.perturb_scale,
),
)
kwargs["recovery_config"] = recovery
elif raw_recovery is not None:
logger.warning(
"NLSQConfig.from_dict: 'recovery' must be a dict, got %r — ignoring",
type(raw_recovery).__name__,
)
# --- Parse nested validation sub-dict ----------------------------
raw_validation = normalized_config.get("validation")
if isinstance(raw_validation, dict):
defaults = NLSQValidationConfig()
validation = NLSQValidationConfig(
chi2_warn_low=safe_float(
raw_validation.get("chi2_warn_low"), defaults.chi2_warn_low
),
chi2_warn_high=safe_float(
raw_validation.get("chi2_warn_high"), defaults.chi2_warn_high
),
chi2_fail_high=safe_float(
raw_validation.get("chi2_fail_high"), defaults.chi2_fail_high
),
max_relative_uncertainty=safe_float(
raw_validation.get("max_relative_uncertainty"),
defaults.max_relative_uncertainty,
),
correlation_warn=safe_float(
raw_validation.get("correlation_warn"), defaults.correlation_warn
),
)
kwargs["validation"] = validation
elif raw_validation is not None:
logger.warning(
"NLSQConfig.from_dict: 'validation' must be a dict, got %r — ignoring",
type(raw_validation).__name__,
)
return cls(**kwargs)
[docs]
def to_dict(self) -> dict[str, Any]:
    """Serialise the configuration to a plain dictionary.

    The two nested dataclasses are emitted as nested dicts under the
    ``"recovery"`` (from ``recovery_config``) and ``"validation"`` keys;
    every other setting is emitted as a flat top-level key using the
    same names that ``from_dict`` accepts.

    The flat hybrid-streaming (``hybrid_*``), multi-start
    (``enable_multi_start`` / ``multi_start_*``), quality-validation
    (``enable_quality_validation`` / ``quality_*``) and
    ``trust_region_scale`` settings are included so that they survive a
    ``from_dict(cfg.to_dict())`` round trip instead of silently falling
    back to their defaults.

    Returns:
        Dictionary representation of the configuration.
    """
    # NOTE(review): settings that ``from_dict`` reads only from nested
    # sections (e.g. ``cmaes.preset``, ``progress.enable``) are not
    # emitted here — confirm against the dataclass field list whether
    # they should be added as well.
    recovery = self.recovery_config
    validation = self.validation
    return {
        # Core solver
        "max_iterations": self.max_iterations,
        "tolerance": self.tolerance,
        "method": self.method,
        "multistart": self.multistart,
        "multistart_n": self.multistart_n,
        "verbose": self.verbose,
        "use_jac": self.use_jac,
        "x_scale": self.x_scale,
        "ftol": self.ftol,
        "xtol": self.xtol,
        "gtol": self.gtol,
        "loss": self.loss,
        "diff_step": self.diff_step,
        "max_nfev": self.max_nfev,
        "chunk_size": self.chunk_size,
        # Workflow / goal
        "workflow": self.workflow,
        "goal": self.goal,
        # Streaming / stratified
        "enable_streaming": self.enable_streaming,
        "streaming_chunk_size": self.streaming_chunk_size,
        "enable_stratified": self.enable_stratified,
        "target_chunk_size": self.target_chunk_size,
        # Recovery
        "enable_recovery": self.enable_recovery,
        "max_recovery_attempts": self.max_recovery_attempts,
        "recovery": {
            "max_retries": recovery.max_retries,
            "lr_decay": recovery.lr_decay,
            "lambda_growth": recovery.lambda_growth,
            "trust_decay": recovery.trust_decay,
            "perturb_scale": recovery.perturb_scale,
        },
        # Diagnostics
        "enable_diagnostics": self.enable_diagnostics,
        "enable_anti_degeneracy": self.enable_anti_degeneracy,
        # Loss / scaling (x_scale_map is copied so callers cannot mutate
        # the configuration through the returned dict)
        "x_scale_map": dict(self.x_scale_map),
        "loss_weights": self.loss_weights,
        "loss_scale": self.loss_scale,
        "tr_solver": self.tr_solver,
        "step_bound": self.step_bound,
        # Fourier reparameterization
        "per_angle_mode": self.per_angle_mode,
        "fourier_order": self.fourier_order,
        "fourier_auto_threshold": self.fourier_auto_threshold,
        # Hierarchical optimization
        "enable_hierarchical": self.enable_hierarchical,
        "hierarchical_max_outer_iterations": self.hierarchical_max_outer_iterations,
        "hierarchical_inner_tolerance": self.hierarchical_inner_tolerance,
        "hierarchical_outer_tolerance": self.hierarchical_outer_tolerance,
        # Adaptive regularization
        "regularization_mode": self.regularization_mode,
        "group_variance_lambda": self.group_variance_lambda,
        "regularization_target_cv": self.regularization_target_cv,
        # Gradient collapse detection
        "enable_gradient_monitoring": self.enable_gradient_monitoring,
        "gradient_ratio_threshold": self.gradient_ratio_threshold,
        "gradient_consecutive_triggers": self.gradient_consecutive_triggers,
        # CMA-ES global search
        "enable_cmaes": self.enable_cmaes,
        "cmaes_sigma0": self.cmaes_sigma0,
        "cmaes_max_iterations": self.cmaes_max_iterations,
        "cmaes_population_size": self.cmaes_population_size,
        "cmaes_tolx": self.cmaes_tolx,
        "cmaes_tolfun": self.cmaes_tolfun,
        "cmaes_diagonal_filtering": self.cmaes_diagonal_filtering,
        "cmaes_anti_degeneracy": self.cmaes_anti_degeneracy,
        "cmaes_warmstart_auto_skip": self.cmaes_warmstart_auto_skip,
        "cmaes_warmstart_skip_threshold": self.cmaes_warmstart_skip_threshold,
        "cmaes_restart_strategy": self.cmaes_restart_strategy,
        "cmaes_max_restarts": self.cmaes_max_restarts,
        # Hybrid streaming optimizer
        "hybrid_enable": self.hybrid_enable,
        "hybrid_warmup_fraction": self.hybrid_warmup_fraction,
        "hybrid_normalization": self.hybrid_normalization,
        "hybrid_method": self.hybrid_method,
        "hybrid_lbfgs_memory": self.hybrid_lbfgs_memory,
        "hybrid_convergence_window": self.hybrid_convergence_window,
        "hybrid_convergence_threshold": self.hybrid_convergence_threshold,
        "hybrid_max_phases": self.hybrid_max_phases,
        # Hybrid streaming optimizer (homodyne-parity fields)
        "hybrid_warmup_learning_rate": self.hybrid_warmup_learning_rate,
        "hybrid_gauss_newton_max_iterations": self.hybrid_gauss_newton_max_iterations,
        "hybrid_gauss_newton_tol": self.hybrid_gauss_newton_tol,
        "hybrid_chunk_size": self.hybrid_chunk_size,
        "hybrid_trust_region_initial": self.hybrid_trust_region_initial,
        "hybrid_regularization_factor": self.hybrid_regularization_factor,
        "hybrid_enable_checkpoints": self.hybrid_enable_checkpoints,
        "hybrid_checkpoint_frequency": self.hybrid_checkpoint_frequency,
        "hybrid_validate_numerics": self.hybrid_validate_numerics,
        "hybrid_enable_warm_start_detection": self.hybrid_enable_warm_start_detection,
        "hybrid_warm_start_threshold": self.hybrid_warm_start_threshold,
        "hybrid_enable_adaptive_warmup_lr": self.hybrid_enable_adaptive_warmup_lr,
        "hybrid_warmup_lr_refinement": self.hybrid_warmup_lr_refinement,
        "hybrid_warmup_lr_careful": self.hybrid_warmup_lr_careful,
        "hybrid_enable_cost_guard": self.hybrid_enable_cost_guard,
        "hybrid_cost_increase_tolerance": self.hybrid_cost_increase_tolerance,
        "hybrid_enable_step_clipping": self.hybrid_enable_step_clipping,
        "hybrid_max_warmup_step_size": self.hybrid_max_warmup_step_size,
        # Multi-start extensions (legacy heterodyne fields)
        "sampling_strategy": self.sampling_strategy,
        "screen_keep_fraction": self.screen_keep_fraction,
        "refine_top_k": self.refine_top_k,
        # Multi-start optimization (homodyne-parity fields)
        "enable_multi_start": self.enable_multi_start,
        "multi_start_n_starts": self.multi_start_n_starts,
        "multi_start_seed": self.multi_start_seed,
        "multi_start_sampling_strategy": self.multi_start_sampling_strategy,
        "multi_start_n_workers": self.multi_start_n_workers,
        "multi_start_use_screening": self.multi_start_use_screening,
        "multi_start_screen_keep_fraction": self.multi_start_screen_keep_fraction,
        "multi_start_refine_top_k": self.multi_start_refine_top_k,
        "multi_start_refinement_ftol": self.multi_start_refinement_ftol,
        "multi_start_degeneracy_threshold": self.multi_start_degeneracy_threshold,
        # Fit quality validation (homodyne-parity fields)
        "enable_quality_validation": self.enable_quality_validation,
        "quality_reduced_chi_squared_threshold": self.quality_reduced_chi_squared_threshold,
        "quality_warn_on_max_restarts": self.quality_warn_on_max_restarts,
        "quality_warn_on_bounds_hit": self.quality_warn_on_bounds_hit,
        "quality_warn_on_convergence_failure": self.quality_warn_on_convergence_failure,
        "quality_bounds_tolerance": self.quality_bounds_tolerance,
        # Scaling threshold
        "constant_scaling_threshold": self.constant_scaling_threshold,
        # Backend / model
        "use_nlsq_library": self.use_nlsq_library,
        "n_params": self.n_params,
        "analysis_mode": self.analysis_mode,
        # NLSQ package integration
        "nlsq_stability": self.nlsq_stability,
        "nlsq_rescale_data": self.nlsq_rescale_data,
        # NOTE(review): from_dict treats nlsq_x_scale as "str or
        # np.ndarray"; an array is emitted as-is and would need
        # converting before a YAML / JSON dump — confirm with callers.
        "nlsq_x_scale": self.nlsq_x_scale,
        "nlsq_memory_fraction": self.nlsq_memory_fraction,
        "nlsq_memory_fallback_gb": self.nlsq_memory_fallback_gb,
        # Loss function scale (homodyne-parity)
        "trust_region_scale": self.trust_region_scale,
        # Validation
        "validation": {
            "chi2_warn_low": validation.chi2_warn_low,
            "chi2_warn_high": validation.chi2_warn_high,
            "chi2_fail_high": validation.chi2_fail_high,
            "max_relative_uncertainty": validation.max_relative_uncertainty,
            "correlation_warn": validation.correlation_warn,
        },
    }
[docs]
@classmethod
def from_yaml(cls, yaml_path: str) -> NLSQConfig:
    """Create NLSQConfig from YAML configuration file.

    Reads the YAML file, extracts the ``optimization.nlsq`` section and
    hands it to ``from_dict``. A missing, empty or ``null``
    ``optimization`` / ``optimization.nlsq`` section is treated as an
    empty mapping: a warning is logged and ``from_dict`` receives ``{}``.

    Args:
        yaml_path: Path to YAML configuration file.

    Returns:
        ``NLSQConfig`` instance built by ``from_dict``.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
        ValueError: If the top-level document, the ``optimization``
            section or the ``optimization.nlsq`` section is present but
            is not a mapping.

    Notes:
        YAML syntax errors raised by ``yaml.safe_load`` are propagated
        unchanged.

    Examples:
        >>> config = NLSQConfig.from_yaml("heterodyne_config.yaml")
        >>> print(config.loss)
        soft_l1
    """
    from pathlib import Path

    import yaml

    def _as_mapping(value: Any, label: str) -> dict[str, Any]:
        """Return *value* as a dict, mapping ``None`` to ``{}``."""
        if value is None:
            # An empty file or a key written as ``key:`` loads as None;
            # treat it the same as an absent section.
            return {}
        if not isinstance(value, dict):
            raise ValueError(
                f"Invalid configuration in {yaml_path}: {label} must be a "
                f"mapping, got {type(value).__name__}"
            )
        return value

    path = Path(yaml_path)
    if not path.exists():
        raise FileNotFoundError(f"Configuration file not found: {yaml_path}")

    with open(path, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    # Extract optimization.nlsq, validating each level on the way down so
    # that malformed files fail with the documented ValueError rather
    # than an AttributeError / TypeError deep inside from_dict.
    full_config = _as_mapping(loaded, "top-level document")
    optimization = _as_mapping(full_config.get("optimization"), "'optimization'")
    nlsq_config = _as_mapping(optimization.get("nlsq"), "'optimization.nlsq'")

    if not nlsq_config:
        logger.warning(
            "No optimization.nlsq section found in %s, using defaults",
            yaml_path,
        )
    return cls.from_dict(nlsq_config)
[docs]
def is_valid(self) -> bool:
    """Report whether ``validate()`` found no problems.

    Returns:
        ``True`` when the collection returned by ``validate()`` has
        length zero, ``False`` when it contains at least one entry.
    """
    issues = self.validate()
    return len(issues) == 0
[docs]
def to_workflow_kwargs(self) -> dict[str, Any]:
    """Build the keyword arguments forwarded to NLSQ's ``curve_fit()``.

    Translates ``NLSQConfig`` settings into NLSQ 0.6.10+ ``curve_fit()``
    parameters. Heterodyne calls ``curve_fit()`` directly instead of the
    unified ``fit()`` API.

    Returns:
        Dictionary with the keys ``ftol``, ``gtol``, ``xtol``,
        ``max_nfev`` and ``loss``, preceded by ``goal`` when ``goal`` is
        anything other than ``"quality"``.

    Notes:
        NLSQ 0.6.3+ workflows are ``"auto"``, ``"auto_global"`` and
        ``"hpc"``; the old presets (``"streaming"``, ``"standard"``)
        were removed. ``workflow`` is not forwarded because heterodyne
        performs its own memory-aware strategy selection.
    """
    # "quality" is omitted so that NLSQ's own default is not overridden;
    # any other goal is forwarded and placed first in the mapping.
    kwargs: dict[str, Any] = {"goal": self.goal} if self.goal != "quality" else {}

    # NOTE(review): max_nfev is filled from max_iterations; the config's
    # separate max_nfev attribute is not consulted here — confirm this
    # is intended.
    kwargs.update(
        ftol=self.ftol,
        gtol=self.gtol,
        xtol=self.xtol,
        max_nfev=self.max_iterations,
        loss=self.loss,
    )
    return kwargs