Source code for heterodyne.optimization.nlsq.core

"""Core NLSQ fitting for heterodyne analysis.

Unified entry point for NLSQ optimization with:
- Global optimization selection (CMA-ES → multi-start → local)
- Adapter/wrapper fallback with automatic recovery
- Memory-aware strategy selection
- Per-angle and multi-angle fitting
"""

from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any

import jax.numpy as jnp
import numpy as np

from heterodyne.core.jax_backend import (
    compute_c2_heterodyne,
    compute_multi_angle_residuals,
    compute_residuals,
)
from heterodyne.optimization.nlsq.config import NLSQConfig
from heterodyne.optimization.nlsq.results import NLSQResult
from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    # Static-analyser imports: give Pyright/mypy the names unconditionally
    # so call sites guarded by ``HAS_X`` flags don't trip "possibly
    # unbound" diagnostics. Runtime behaviour is controlled by the
    # try/except blocks below.
    from heterodyne.core.heterodyne_model import HeterodyneModel
    from heterodyne.optimization.nlsq.adapter import (
        LowLevelNLSQWrapper as NLSQWrapper,
    )
    from heterodyne.optimization.nlsq.adapter import NLSQAdapter
    from heterodyne.optimization.nlsq.cmaes_wrapper import fit_with_cmaes
    from heterodyne.optimization.nlsq.memory import (
        NLSQStrategy,
        select_nlsq_strategy,
    )
    from heterodyne.optimization.nlsq.multistart import (
        MultiStartConfig,
        MultiStartOptimizer,
    )

logger = get_logger(__name__)

# ---------------------------------------------------------------------------
# Optional imports — gated for graceful degradation
# ---------------------------------------------------------------------------

try:
    # NLSQWrapper here is the LOW-LEVEL wrapper (adapter.NLSQWrapper);
    # alias it explicitly to avoid shadowing wrapper.NLSQWrapper (the public
    # high-level stable-fallback re-exported from heterodyne.optimization.nlsq).
    from heterodyne.optimization.nlsq.adapter import (  # noqa: F811
        LowLevelNLSQWrapper as NLSQWrapper,
    )
    from heterodyne.optimization.nlsq.adapter import NLSQAdapter  # noqa: F811

    HAS_ADAPTERS = True
    HAS_WRAPPER = True
except ImportError:
    # Call sites are guarded by ``HAS_ADAPTERS`` / ``HAS_WRAPPER``; the
    # names above remain unbound but are visible to static analysers via
    # the ``TYPE_CHECKING`` block.
    HAS_ADAPTERS = False
    HAS_WRAPPER = False

try:
    from heterodyne.optimization.nlsq.multistart import (  # noqa: F811
        MultiStartConfig,
        MultiStartOptimizer,
    )

    HAS_MULTISTART = True
except ImportError:
    HAS_MULTISTART = False

try:
    from heterodyne.optimization.nlsq.cmaes_wrapper import (
        CMAES_AVAILABLE,
        fit_with_cmaes,  # noqa: F811
    )

    HAS_CMAES = CMAES_AVAILABLE
except ImportError:
    HAS_CMAES = False

try:
    from heterodyne.optimization.nlsq.memory import (  # noqa: F811
        NLSQStrategy,
        select_nlsq_strategy,
    )

    HAS_MEMORY = True
except ImportError:
    HAS_MEMORY = False

# Export availability flag for tests
NLSQ_AVAILABLE = HAS_ADAPTERS


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def fit_nlsq_jax(
    model: HeterodyneModel,
    c2_data: np.ndarray | jnp.ndarray,
    phi_angle: float = 0.0,
    config: NLSQConfig | None = None,
    weights: np.ndarray | jnp.ndarray | None = None,
    use_nlsq_library: bool = True,
    *,
    _skip_global_selection: bool = False,
) -> NLSQResult:
    """Fit the heterodyne model to a single correlation matrix with NLSQ.

    This is the unified entry point for every single-angle NLSQ fit.
    Unless ``_skip_global_selection`` is set, a global optimizer is
    considered first (via ``_try_global_optimization``):

    1. ``cmaes.enable: true``       -> the fit is delegated to CMA-ES;
    2. ``multi_start.enable: true`` -> the fit is delegated to multi-start;
    3. neither                       -> a local trust-region fit runs.

    In the local path the adapter is tried first; on failure the wrapper
    retries with progressive recovery (HybridRecoveryConfig).

    Input validation runs before any fit and only logs warnings. Result
    validation runs on whichever result is finally returned.

    Args:
        model: HeterodyneModel instance with parameters configured.
        c2_data: Experimental correlation data, shape (N, N).
        phi_angle: Detector phi angle (degrees).
        config: NLSQ configuration; a default ``NLSQConfig()`` when None.
        weights: Optional weights (1/sigma²) for weighted least squares.
        use_nlsq_library: Whether to prefer the nlsq library over scipy.
        _skip_global_selection: Internal flag. When True, the CMA-ES /
            multi-start check is bypassed and the local fit runs directly.

    Returns:
        NLSQResult with fitted parameters and diagnostics.
    """
    cfg = NLSQConfig() if config is None else config

    banner = "=" * 60
    logger.info(banner)
    logger.info("NLSQ OPTIMIZATION")
    logger.info(banner)
    logger.info("phi=%s°, method=%s", phi_angle, cfg.method)

    # Non-strict input check (homodyne parity). It only emits warnings on
    # shape/bounds/finiteness issues; the fit proceeds either way. Strict
    # mode is opt-in via the validator's constructor.
    _run_input_validation(model=model, c2_data=c2_data)

    result: NLSQResult | None = None
    if not _skip_global_selection:
        # Global selection order: CMA-ES -> multi-start. The helper yields
        # None when no global method is enabled.
        result = _try_global_optimization(
            model,
            c2_data,
            phi_angle,
            cfg,
            weights,
            use_nlsq_library,
        )
        if result is None:
            logger.debug("No global optimization enabled, using local optimization")

    if result is None:
        # Local trust-region optimization (adapter first, wrapper fallback).
        result = _fit_local(model, c2_data, phi_angle, cfg, weights, use_nlsq_library)

    _run_result_validation(result=result, model=model)
    return result
def fit_nlsq_multi_phi(
    model: HeterodyneModel,
    c2_data: np.ndarray,
    phi_angles: list[float] | np.ndarray,
    config: NLSQConfig | None = None,
    weights: np.ndarray | None = None,
) -> list[NLSQResult]:
    """Fit model to correlation data at multiple phi angles.

    When ``config`` is given and more than one angle is supplied, the first
    matching strategy below runs and its result is returned:

    1. **Joint CMA-ES** -- ``config.enable_cmaes`` is truthy and CMA-ES is
       importable (``HAS_CMAES``).
    2. **Fixed-constant scaling** -- selected by
       ``_use_fixed_constant_scaling_mode``. Per-angle contrast/offset are
       frozen from a quantile estimate, and only the physics parameters are
       fit jointly.
    3. **Averaged scaling** -- selected by
       ``_use_averaged_constant_scaling_mode``. The physics parameters plus
       one averaged contrast/offset pair are fit jointly.
    4. **Joint Fourier / independent fit** -- all angles are fit
       simultaneously in a single optimization.
       - In ``"fourier"`` mode the optimizer vector is
         ``[physics_varying | fourier_contrast_coeffs | fourier_offset_coeffs]``,
         and the Fourier basis constrains smooth angular variation.
       - In ``"independent"`` mode each angle has its own contrast/offset
         (``2*n_phi`` scaling parameters), all optimized jointly.

    Otherwise the angles are fit one at a time with :func:`fit_nlsq_jax`.
    This happens when ``config`` is None, only one angle is given, or the
    ``fourier_reparam`` module cannot be imported. The same ``model`` object
    is passed to every call, and the log describes this as a warm-start
    chain. Presumably ``fit_nlsq_jax`` leaves fitted values on ``model``;
    this is not verified here.

    Args:
        model: HeterodyneModel instance.
        c2_data: Correlation data, shape ``(n_phi, N, N)`` or ``(N, N)``.
            A 2-D matrix is treated as a single angle.
        phi_angles: Phi angles in degrees. A scalar is accepted and treated
            as a single angle.
        config: NLSQ configuration. When None, the sequential path is used.
        weights: Optional weights, shape ``(n_phi, N, N)`` (one matrix per
            angle) or ``(N, N)`` (shared by all angles).

    Returns:
        List of :class:`NLSQResult`, one per angle. In the sequential path
        each result's ``metadata["phi_angle"]`` is set to its angle in
        degrees.

    Raises:
        ValueError: If ``c2_data`` is not 2-D or 3-D, if the number of C2
            matrices differs from the number of angles, or if 3-D
            ``weights`` do not provide exactly one matrix per angle.
    """
    # atleast_1d lets a bare scalar angle through instead of crashing on
    # len() of a 0-d array.
    phi_angles = np.atleast_1d(np.asarray(phi_angles))
    if c2_data.ndim == 2:
        c2_data = c2_data[np.newaxis, ...]
    if c2_data.ndim != 3:
        raise ValueError(
            "c2_data must have shape (N, N) or (n_phi, N, N); "
            f"got array with ndim={c2_data.ndim}"
        )

    n_phi = len(phi_angles)
    if len(c2_data) != n_phi:
        raise ValueError(
            f"Number of c2 matrices ({len(c2_data)}) doesn't match "
            f"number of phi angles ({n_phi})"
        )

    # 3-D weights are indexed per angle; reject a count mismatch up front
    # rather than failing with an IndexError deep inside the fit loop.
    weights_per_angle = weights is not None and weights.ndim == 3
    if weights_per_angle and len(weights) != n_phi:
        raise ValueError(
            f"Number of weight matrices ({len(weights)}) doesn't match "
            f"number of phi angles ({n_phi})"
        )

    # Anti-degeneracy diagnostics (homodyne parity). The ported
    # AntiDegeneracyController is consulted here as an observer so its
    # mode-resolution and group-variance diagnostics run on every
    # multi-angle fit.
    _log_anti_degeneracy_diagnostics(
        config=config,
        phi_angles=phi_angles,
    )

    if config is not None and n_phi > 1:
        if getattr(config, "enable_cmaes", False) and HAS_CMAES:
            logger.info("CMA-ES enabled, delegating to joint multi-angle CMA-ES")
            return _fit_joint_cmaes_multi_phi(
                model=model,
                c2_data=c2_data,
                phi_angles=phi_angles,
                config=config,
                weights=weights,
            )

        constant_threshold = max(
            int(getattr(config, "constant_scaling_threshold", 3)), 1
        )

        if _use_fixed_constant_scaling_mode(config, n_phi):
            logger.info(
                "Fixed-constant scaling selected (homodyne 'constant' parity): "
                "per-angle β,o frozen from quantile, n_phi=%d "
                "(joint fit of 14 physics only)",
                n_phi,
            )
            return _fit_joint_fixed_constant_multi_phi(
                model=model,
                c2_data=c2_data,
                phi_angles=phi_angles,
                config=config,
                weights=weights,
            )

        if _use_averaged_constant_scaling_mode(config, n_phi):
            logger.info(
                "Auto averaged scaling selected: mode=%s, n_phi=%d, threshold=%d "
                "(joint fit of 14 physics + 2 averaged β,o = 16 params)",
                config.per_angle_mode,
                n_phi,
                constant_threshold,
            )
            return _fit_joint_averaged_multi_phi(
                model=model,
                c2_data=c2_data,
                phi_angles=phi_angles,
                config=config,
                weights=weights,
            )

        # Only the optional import is guarded. Errors raised while building
        # the reparameterizer must surface instead of silently degrading to
        # sequential fits.
        try:
            from heterodyne.optimization.nlsq.fourier_reparam import (
                FourierReparamConfig,
                FourierReparameterizer,
            )
        except ImportError:
            logger.warning(
                "fourier_reparam not available, falling back to sequential fits"
            )
        else:
            fourier_config = FourierReparamConfig(
                mode=config.per_angle_mode,
                fourier_order=config.fourier_order,
                auto_threshold=config.fourier_auto_threshold,
            )
            phi_rad = np.deg2rad(phi_angles.astype(np.float64))
            fourier = FourierReparameterizer(phi_rad, fourier_config)
            return _fit_joint_multi_phi(
                model,
                c2_data,
                phi_angles,
                config,
                weights,
                fourier,
            )

    # Sequential per-angle fitting (warm-start chain).
    results: list[NLSQResult] = []
    for i, phi in enumerate(phi_angles):
        if i > 0:
            logger.info(
                "Fitting phi angle %d/%d: %s° (warm-start from angle %s°)",
                i + 1,
                n_phi,
                phi,
                phi_angles[i - 1],
            )
        else:
            logger.info("Fitting phi angle %d/%d: %s°", i + 1, n_phi, phi)

        weights_i = weights[i] if weights_per_angle else weights
        result = fit_nlsq_jax(
            model=model,
            c2_data=c2_data[i],
            phi_angle=float(phi),
            config=config,
            weights=weights_i,
        )
        result.metadata["phi_angle"] = float(phi)
        results.append(result)

    return results
# --------------------------------------------------------------------------- # Public global-optimization entry points (parity with homodyne) # ---------------------------------------------------------------------------
def fit_nlsq_cmaes(
    model: HeterodyneModel,
    c2_data: np.ndarray | jnp.ndarray,
    phi_angle: float = 0.0,
    config: NLSQConfig | None = None,
    weights: np.ndarray | jnp.ndarray | None = None,
) -> NLSQResult:
    """Run a CMA-ES global fit for multi-scale parameter problems.

    This is a thin public wrapper around ``_fit_cmaes``. It performs two
    checks in order:

    1. The ``cma`` package is available (``ImportError`` otherwise).
    2. CMA-ES is switched on in the configuration (``ValueError``
       otherwise). A default ``NLSQConfig()`` is used when *config* is None.
       Presumably that default leaves CMA-ES disabled, so callers normally
       pass an explicit config; confirm against ``NLSQConfig``.

    Args:
        model: HeterodyneModel instance with parameters configured.
        c2_data: Experimental correlation data, shape (N, N).
        phi_angle: Detector phi angle (degrees).
        config: NLSQ configuration. Defaults to ``NLSQConfig()``.
        weights: Optional weights (1/σ²) for weighted least squares.

    Returns:
        NLSQResult with fitted parameters and diagnostics.

    Raises:
        ImportError: If CMA-ES (``cma`` package) is not available.
        ValueError: If CMA-ES is not enabled in *config*.

    Examples:
        >>> config = NLSQConfig(enable_cmaes=True)
        >>> result = fit_nlsq_cmaes(model, c2_data, config=config)
        >>> print(f"Chi2: {result.reduced_chi_squared:.4f}")
    """
    if not HAS_CMAES:
        raise ImportError("CMA-ES requires the 'cma' package. Install with: uv add cma")

    cfg = NLSQConfig() if config is None else config
    cmaes_enabled = getattr(cfg, "enable_cmaes", False)
    if not cmaes_enabled:
        raise ValueError(
            "CMA-ES optimization is not enabled. Set enable_cmaes=True in NLSQConfig."
        )

    return _fit_cmaes(model, c2_data, phi_angle, cfg, weights)
def fit_nlsq_multistart(
    model: HeterodyneModel,
    c2_data: np.ndarray | jnp.ndarray,
    phi_angle: float = 0.0,
    config: NLSQConfig | None = None,
    weights: np.ndarray | jnp.ndarray | None = None,
    use_nlsq_library: bool = True,
) -> NLSQResult:
    """Run multi-start NLSQ optimization seeded by Latin Hypercube Sampling.

    This is the public wrapper around ``_fit_multistart``. It explores the
    parameter space from several starting points to avoid local minima.
    The FULL strategy is always used — no subsampling.

    Preconditions are checked in order:

    1. The multi-start module is importable (``ImportError`` otherwise).
    2. Multi-start is enabled in the configuration (``ValueError``
       otherwise). A default ``NLSQConfig()`` is used when *config* is None.

    Args:
        model: HeterodyneModel instance with parameters configured.
        c2_data: Experimental correlation data, shape (N, N).
        phi_angle: Detector phi angle (degrees).
        config: NLSQ configuration. Defaults to ``NLSQConfig()``.
        weights: Optional weights (1/σ²) for weighted least squares.
        use_nlsq_library: Whether to prefer the nlsq library over scipy.

    Returns:
        NLSQResult with the best parameters across all starts.

    Raises:
        ImportError: If the multi-start module is not available.
        ValueError: If multi-start is not enabled in *config*.

    Examples:
        >>> config = NLSQConfig(multistart=True, multistart_n=20)
        >>> result = fit_nlsq_multistart(model, c2_data, config=config)
        >>> print(f"Best chi2: {result.reduced_chi_squared:.4f}")
    """
    if not HAS_MULTISTART:
        raise ImportError(
            "Multi-start optimization requires the multistart module. "
            "Ensure heterodyne.optimization.nlsq.multistart is importable."
        )

    cfg = NLSQConfig() if config is None else config
    multistart_enabled = getattr(cfg, "multistart", False)
    if not multistart_enabled:
        raise ValueError(
            "Multi-start optimization is not enabled. "
            "Set multistart=True in NLSQConfig."
        )

    return _fit_multistart(model, c2_data, phi_angle, cfg, weights, use_nlsq_library)
def _compute_per_angle_chi2( residuals: np.ndarray, c2_matrix: np.ndarray, n_params: int, ) -> tuple[float, float]: """Compute per-angle cost and noise-normalised reduced chi-squared. Joint fits produce one aggregated cost and chi2 for all angles. This helper reconstructs the per-angle statistics so each NLSQResult carries its own diagnostics rather than a copy of the joint value. Args: residuals: Flat off-diagonal residual vector from compute_residuals, length n*(n-1). c2_matrix: Per-angle experimental C2 matrix, shape (n, n). n_params: Number of varying physics parameters. Returns: ``(per_angle_cost, reduced_chi_squared)`` where ``per_angle_cost`` is ``0.5*SSR`` and ``reduced_chi_squared`` is noise-normalised (target ≈ 1.0 for a good fit; MSE fallback when noise is degenerate). """ ssr = float(np.sum(residuals**2)) per_angle_cost = 0.5 * ssr n_matrix = c2_matrix.shape[0] n_valid = (n_matrix - 1) * ( n_matrix - 2 ) # valid off-diagonal non-boundary residuals n_dof = max(n_valid - n_params, 1) # Far-lag photon-noise estimate — same formula as _fit_local c2_np = np.asarray(c2_matrix) row_idx = np.arange(n_matrix) lag_mat = np.abs(row_idx[:, None] - row_idx[None, :]) far_vals = c2_np[lag_mat >= n_matrix // 2] sigma2_noise = float(np.var(far_vals)) if far_vals.size > 1 else 0.0 if sigma2_noise > 1e-12: reduced_chi2 = ssr / (sigma2_noise * n_dof) else: reduced_chi2 = ssr / n_dof # MSE fallback return per_angle_cost, reduced_chi2 def _anti_degen_dict_from_config(config: NLSQConfig) -> dict[str, Any]: """Project ``NLSQConfig`` fields onto the ``AntiDegeneracyConfig`` schema. Mirrors the shape consumed by :meth:`AntiDegeneracyController.from_config` so the joint-fit code paths can construct a controller from the active ``NLSQConfig`` without duplicating the projection logic. Returns a nested ``dict`` matching the YAML structure under ``optimization.nlsq.anti_degeneracy``. 
""" return { "enable": config.enable_anti_degeneracy, "per_angle_mode": config.per_angle_mode, "fourier_order": config.fourier_order, "fourier_auto_threshold": config.fourier_auto_threshold, "constant_scaling_threshold": getattr(config, "constant_scaling_threshold", 3), "hierarchical": { "enable": config.enable_hierarchical, "max_outer_iterations": config.hierarchical_max_outer_iterations, "outer_tolerance": config.hierarchical_outer_tolerance, "physical_max_iterations": config.hierarchical_physical_max_iterations, "per_angle_max_iterations": config.hierarchical_per_angle_max_iterations, }, "regularization": { "mode": config.regularization_mode, "lambda": config.group_variance_lambda, "target_cv": config.regularization_target_cv, "target_contribution": config.regularization_target_contribution, "max_cv": config.regularization_max_cv, }, "gradient_monitoring": { "enable": config.enable_gradient_monitoring, "ratio_threshold": config.gradient_ratio_threshold, "consecutive_triggers": config.gradient_consecutive_triggers, "response": config.gradient_collapse_response, }, } def _build_anti_degen_controller( *, config: NLSQConfig, n_phi: int, phi_angles: np.ndarray, n_physical: int, ) -> Any | None: """Construct an ``AntiDegeneracyController`` when any active layer is requested. Sub-PR C3: broadened again to also build whenever L4 gradient-collapse monitoring is enabled. L4 is observation-only (the monitor records per-iteration gradient norms and exposes a summary via ``controller.monitor.get_summary()``) and is meaningful even for the fixed-constant path where no per-angle scaling block exists. Sub-PR C2: broadened from the C1 marker-only guard. The controller is built whenever anti-degeneracy is enabled and EITHER L2 hierarchical OR L3 regularization OR L4 gradient monitoring is requested. L1 (Fourier) is dispatched through a separate joint-fit path and does not require a controller here. 
Returns ``None`` when no active defense layer is requested, when the fit is single-angle, or when controller construction fails. Callers are responsible for deriving the L2 marker, L3 callbacks, and L4 monitor summary from the returned controller. """ if not config.enable_anti_degeneracy: return None if not ( config.enable_hierarchical or config.regularization_mode != "none" or config.enable_gradient_monitoring ): return None if n_phi <= 1: return None try: from heterodyne.optimization.nlsq.anti_degeneracy_controller import ( AntiDegeneracyController, ) return AntiDegeneracyController.from_config( config_dict=_anti_degen_dict_from_config(config), n_phi=n_phi, phi_angles=np.deg2rad(np.asarray(phi_angles, dtype=np.float64)), n_physical=n_physical, per_angle_scaling=True, ) except Exception as exc: # noqa: BLE001 — controller failure must not derail fit logger.warning( "Anti-degeneracy controller construction skipped: %s", exc, ) return None def _hierarchical_marker_from_controller( controller: Any | None, ) -> dict[str, Any] | None: """Derive the L2 hierarchical-request marker dict from a built controller. Sub-PR C1 marker (now plumbed through C2's broadened controller). Returns ``None`` if the controller did not enable hierarchical mode, otherwise returns the dict embedded in ``result.metadata["hierarchical_config"]``. ``HierarchicalFitter`` is a single-angle, stage-based parameter- unfreezing strategy (transport → velocity → fraction → all) that requires an ``NLSQAdapterBase`` at construction time — it is **not** the multi-phi outer/inner alternation implied by ``max_outer_iterations``. Wiring it directly into the joint residual path is out of scope here, so this is observe-only. TODO(C1-followup): replace this marker with an active outer/inner loop once the controller learns to drive the joint residual via callbacks (see ``create_nlsq_callbacks`` comment in ``anti_degeneracy_controller.py``). 
""" if controller is None or not getattr(controller, "use_hierarchical", False): return None cfg_dict = dict(controller._hierarchical_config_dict) logger.info( "L2 hierarchical requested but not actively wired (see C1 follow-up); " "marker: max_outer_iterations=%d, outer_tolerance=%.1e", cfg_dict.get("max_outer_iterations", -1), cfg_dict.get("outer_tolerance", float("nan")), ) return { "requested": True, "active": False, "max_outer_iterations": int(cfg_dict.get("max_outer_iterations", 0)), "outer_tolerance": float(cfg_dict.get("outer_tolerance", 0.0)), "physical_max_iterations": int(cfg_dict.get("physical_max_iterations", 0)), "per_angle_max_iterations": int(cfg_dict.get("per_angle_max_iterations", 0)), } def _fit_joint_averaged_multi_phi( model: HeterodyneModel, c2_data: np.ndarray, phi_angles: np.ndarray, config: NLSQConfig, weights: np.ndarray | None, ) -> list[NLSQResult]: """Joint multi-angle fit with AVERAGED contrast/offset scaling. Homodyne parity for ``per_angle_mode="auto"`` (when n_phi >= threshold): per-angle quantile estimates are computed first, averaged to one contrast and one offset, and those two scalars are optimized jointly with the 14 physics parameters (16 total). This is distinct from :func:`_fit_joint_fixed_constant_multi_phi` (added in Sub-PR B), which implements the true homodyne ``"constant"`` semantics by FREEZING per-angle β,o and optimizing only the 14 physics parameters. 
""" from heterodyne.config.parameter_registry import SCALING_PARAMS from heterodyne.core.scaling_utils import compute_averaged_scaling t_start = time.perf_counter() param_manager = model.param_manager varying_names = list(param_manager.varying_names) n_physics_varying = param_manager.n_varying n_phi = len(phi_angles) physics_initial = np.asarray(param_manager.get_initial_values(), dtype=np.float64) physics_lower, physics_upper = param_manager.get_bounds() physics_initial = np.clip(physics_initial, physics_lower, physics_upper) t = model.t q = model.q dt = model.dt t1_mesh, t2_mesh = np.meshgrid(np.asarray(t), np.asarray(t), indexing="ij") n_time_points = t1_mesh.size c2_flat = [] t1_flat = [] t2_flat = [] phi_indices = [] for i in range(n_phi): c2_flat.append(np.asarray(c2_data[i], dtype=np.float64).reshape(-1)) t1_flat.append(t1_mesh.reshape(-1)) t2_flat.append(t2_mesh.reshape(-1)) phi_indices.append(np.full(n_time_points, i, dtype=np.int32)) contrast_bounds = ( SCALING_PARAMS["contrast"].min_bound, SCALING_PARAMS["contrast"].max_bound, ) offset_bounds = ( SCALING_PARAMS["offset"].min_bound, SCALING_PARAMS["offset"].max_bound, ) logger.info("=" * 60) logger.info("AUTO AVERAGED SCALING: Computing per-angle scaling from quantiles") logger.info("=" * 60) avg_contrast, avg_offset, contrast_per_angle, offset_per_angle = ( compute_averaged_scaling( c2_data=np.concatenate(c2_flat), t1=np.concatenate(t1_flat), t2=np.concatenate(t2_flat), phi_indices=np.concatenate(phi_indices), n_phi=n_phi, contrast_bounds=contrast_bounds, offset_bounds=offset_bounds, log=logger, ) ) x0 = np.concatenate([physics_initial, [avg_contrast, avg_offset]]) lb = np.concatenate([physics_lower, [contrast_bounds[0], offset_bounds[0]]]) ub = np.concatenate([physics_upper, [contrast_bounds[1], offset_bounds[1]]]) joint_param_names = [*varying_names, "contrast", "offset"] logger.info( "Joint auto averaged fit: %d physical + 2 averaged scaling = %d total params, %d angles", n_physics_varying, len(x0), 
n_phi, ) # Anti-degeneracy controller construction (Sub-PR C1 + C2). # Built whenever EITHER L2 hierarchical OR L3 regularization is requested. anti_degen_controller = _build_anti_degen_controller( config=config, n_phi=n_phi, phi_angles=np.asarray(phi_angles, dtype=np.float64), n_physical=n_physics_varying, ) hierarchical_marker = _hierarchical_marker_from_controller(anti_degen_controller) c2_data_batch = jnp.asarray(c2_data, dtype=jnp.float64) weights_batch = ( jnp.asarray(weights, dtype=jnp.float64) if weights is not None else jnp.ones_like(c2_data_batch) ) if weights_batch.ndim == 2: weights_batch = jnp.broadcast_to(weights_batch, c2_data_batch.shape) phi_angles_jax = jnp.asarray(phi_angles, dtype=jnp.float64) fixed_values_jax = jnp.asarray(param_manager.get_full_values(), dtype=jnp.float64) varying_indices_jax = jnp.array(param_manager.varying_indices, dtype=jnp.int32) # NOTE: must return a JAX array. NLSQ's masked_residual_func JIT-traces this # closure; np.asarray() on a traced result raises TracerArrayConversionError. def joint_residual_fn(x: np.ndarray) -> Any: # type: ignore[return-value] physics_varying = x[:n_physics_varying] contrast = x[n_physics_varying] offset = x[n_physics_varying + 1] full_jax = fixed_values_jax.at[varying_indices_jax].set( jnp.asarray(physics_varying, dtype=jnp.float64) ) contrasts_jax = jnp.full((n_phi,), contrast, dtype=jnp.float64) offsets_jax = jnp.full((n_phi,), offset, dtype=jnp.float64) return compute_multi_angle_residuals( full_jax, t, q, dt, phi_angles_jax, c2_data_batch, weights_batch, contrasts_jax, offsets_jax, ) # L3 adaptive regularization wiring (Sub-PR C2). # When regularization_mode != "none", append a single Tikhonov penalty # row to the residual vector so the active fit penalises per-angle # scaling variance via the controller's loss-augmentation callback. 
if ( anti_degen_controller is not None and getattr(anti_degen_controller, "regularizer", None) is not None and config.regularization_mode != "none" ): callbacks = anti_degen_controller.create_nlsq_callbacks() loss_aug = callbacks.get("loss_augmentation") if loss_aug is not None: _inner_residual_fn = joint_residual_fn def joint_residual_fn_with_penalty(x: np.ndarray) -> Any: # type: ignore[return-value] """Residual + appended penalty row from controller's loss_aug.""" base = _inner_residual_fn(x) base_np = np.asarray(base) penalty_value = float(loss_aug(np.asarray(x), base_np)) penalty_row = float(np.sqrt(max(2.0 * penalty_value, 0.0))) return jnp.concatenate( [ jnp.asarray(base, dtype=jnp.float64), jnp.array([penalty_row], dtype=jnp.float64), ] ) joint_residual_fn = joint_residual_fn_with_penalty logger.info( "L3 adaptive regularization active (mode=%s, lambda=%.4e)", config.regularization_mode, config.group_variance_lambda, ) joint_config = NLSQConfig( method=config.method if config.method != "lm" else "trf", ftol=config.ftol, xtol=config.xtol, gtol=config.gtol, max_nfev=(config.max_nfev * n_phi if config.max_nfev is not None else None), loss=config.loss, use_nlsq_library=config.use_nlsq_library, n_params=len(x0), ) joint_result: NLSQResult | None = None if HAS_ADAPTERS: try: joint_adapter = NLSQAdapter(parameter_names=joint_param_names) # pyright: ignore[reportPossiblyUnbound] joint_result = joint_adapter.fit( residual_fn=joint_residual_fn, initial_params=x0, bounds=(lb, ub), config=joint_config, ) if not joint_result.success: raise RuntimeError( f"Joint adapter returned success=False: {joint_result.message}" ) except (ValueError, RuntimeError, TypeError) as adapter_exc: logger.warning( "Joint auto averaged NLSQAdapter failed, falling back to NLSQWrapper: %s", adapter_exc, ) joint_result = None if joint_result is None and HAS_WRAPPER: joint_wrapper = NLSQWrapper(parameter_names=joint_param_names) # pyright: ignore[reportPossiblyUnbound] joint_result = 
joint_wrapper.fit( residual_fn=joint_residual_fn, initial_params=x0, bounds=(lb, ub), config=joint_config, ) if joint_result is None: raise ImportError( "No NLSQ backend available for joint auto averaged multi-angle fit." ) # L4 gradient collapse monitor wiring (Sub-PR C3). monitor_summary: dict[str, Any] = {} if ( anti_degen_controller is not None and getattr(anti_degen_controller, "monitor", None) is not None and config.enable_gradient_monitoring ): # TODO(L4-followup): The current adapter loop does not expose a # per-iteration gradient callback to the controller, so # monitor.get_summary() is observation-passive — it returns empty # history. To activate L4 we need either (a) the NLSQAdapter to # call back into the controller via iteration_callback from # create_nlsq_callbacks(), or (b) post-fit gradient-norm sampling # driven by the joint solve's iteration trace. The metadata key # 'gradient_monitor' is currently set only as a placeholder to # surface that L4 was REQUESTED. try: monitor_summary = dict(anti_degen_controller.monitor.get_summary() or {}) except (AttributeError, TypeError) as exc: logger.debug("L4 monitor summary unavailable: %s", exc) monitor_summary = {} if monitor_summary: logger.info("L4 gradient monitor summary: %s", monitor_summary) fitted_all = np.asarray(joint_result.parameters, dtype=np.float64) fitted_physics = fitted_all[:n_physics_varying] fitted_contrast = float(fitted_all[n_physics_varying]) fitted_offset = float(fitted_all[n_physics_varying + 1]) full_fitted = param_manager.expand_varying_to_full(fitted_physics) model.set_params(full_fitted) if hasattr(model, "scaling"): model.scaling.contrast[:] = fitted_contrast model.scaling.offset[:] = fitted_offset wall_time = time.perf_counter() - t_start results: list[NLSQResult] = [] for i, phi in enumerate(phi_angles): fitted_c2 = compute_c2_heterodyne( jnp.asarray(full_fitted), t, q, dt, float(phi), contrast=fitted_contrast, offset=fitted_offset, ) residuals = np.asarray( compute_residuals( 
jnp.asarray(full_fitted), t, q, dt, float(phi), c2_data_batch[i], weights_batch[i], contrast=fitted_contrast, offset=fitted_offset, ) ) per_angle_cost, per_angle_chi2 = _compute_per_angle_chi2( residuals, np.asarray(c2_data_batch[i]), n_physics_varying ) result = NLSQResult( parameters=fitted_physics.copy(), parameter_names=varying_names, uncertainties=( joint_result.uncertainties[:n_physics_varying].copy() if joint_result.uncertainties is not None else None ), covariance=( joint_result.covariance[:n_physics_varying, :n_physics_varying].copy() if joint_result.covariance is not None else None ), residuals=residuals, final_cost=per_angle_cost, reduced_chi_squared=per_angle_chi2, success=bool(joint_result.success), message=str(joint_result.message), n_iterations=joint_result.n_iterations, n_function_evals=joint_result.n_function_evals, convergence_reason=joint_result.convergence_reason, fitted_correlation=np.asarray(fitted_c2), wall_time_seconds=joint_result.wall_time_seconds, metadata={ "phi_angle": float(phi), "contrast": fitted_contrast, "offset": fitted_offset, "contrast_initial_quantile": float(contrast_per_angle[i]), "offset_initial_quantile": float(offset_per_angle[i]), "contrast_initial_average": avg_contrast, "offset_initial_average": avg_offset, "optimizer": "joint_auto_averaged", "n_angles_joint": n_phi, "wall_time_total": wall_time, **( {"hierarchical_config": hierarchical_marker} if hierarchical_marker is not None else {} ), **({"gradient_monitor": monitor_summary} if monitor_summary else {}), }, ) results.append(result) logger.info( "Joint auto averaged fit complete: success=%s, cost=%.6f, " "n_evals=%d, wall_time=%.2fs, %d angles", joint_result.success, joint_result.final_cost or 0.0, joint_result.n_function_evals or 0, wall_time, n_phi, ) return results def _fit_joint_fixed_constant_multi_phi( model: HeterodyneModel, c2_data: np.ndarray, phi_angles: np.ndarray, config: NLSQConfig, weights: np.ndarray | None, ) -> list[NLSQResult]: """Joint multi-angle 
fit with FROZEN per-angle contrast/offset. Homodyne parity for ``per_angle_mode="constant"``: per-angle β(φ) and o(φ) are estimated once from quantile analysis (5 %/95 % of g2 per angle) and held fixed. The optimizer fits only the 14 physics parameters. Distinct from :func:`_fit_joint_averaged_multi_phi`, which AVERAGES per-angle estimates to a single scalar β̄, ō and optimizes them jointly with physics (16 params total — the homodyne ``"auto"`` behaviour). """ from heterodyne.config.parameter_registry import SCALING_PARAMS from heterodyne.core.quantile_scaling import compute_per_angle_quantile_scaling t_start = time.perf_counter() param_manager = model.param_manager varying_names = list(param_manager.varying_names) n_physics_varying = param_manager.n_varying n_phi = len(phi_angles) physics_initial = np.asarray(param_manager.get_initial_values(), dtype=np.float64) physics_lower, physics_upper = param_manager.get_bounds() physics_initial = np.clip(physics_initial, physics_lower, physics_upper) contrast_bounds = ( SCALING_PARAMS["contrast"].min_bound, SCALING_PARAMS["contrast"].max_bound, ) offset_bounds = ( SCALING_PARAMS["offset"].min_bound, SCALING_PARAMS["offset"].max_bound, ) # Per-angle quantile estimates (the FREEZE step). 
g2_flat_parts = [] phi_idx_parts = [] for i in range(n_phi): flat = np.asarray(c2_data[i], dtype=np.float64).reshape(-1) g2_flat_parts.append(flat) phi_idx_parts.append(np.full(flat.size, i, dtype=np.int32)) quantile = compute_per_angle_quantile_scaling( g2=np.concatenate(g2_flat_parts), phi_indices=np.concatenate(phi_idx_parts), n_phi=n_phi, contrast_bounds=contrast_bounds, offset_bounds=offset_bounds, ) contrast_per_angle = quantile.contrast_per_angle offset_per_angle = quantile.offset_per_angle logger.info("=" * 60) logger.info( "FIXED CONSTANT scaling (homodyne 'constant' parity): " "per-angle β,o frozen from quantile, n_phi=%d", n_phi, ) logger.info( " contrast range: [%.4f, %.4f]; offset range: [%.4f, %.4f]", float(contrast_per_angle.min()), float(contrast_per_angle.max()), float(offset_per_angle.min()), float(offset_per_angle.max()), ) logger.info("=" * 60) # Optimizer sees only the physics parameters; β,o enter the residual # closure as captured constants. x0 = physics_initial.copy() lb = np.asarray(physics_lower, dtype=np.float64).copy() ub = np.asarray(physics_upper, dtype=np.float64).copy() joint_param_names = list(varying_names) logger.info( "Joint fixed-constant fit: %d physics params (frozen β,o), %d angles", n_physics_varying, n_phi, ) # L4 gradient collapse monitor wiring (Sub-PR C3). # Fixed-constant has no per-angle scaling block, so L2/L3 don't apply # (per_angle_scaling=False), but L4 is observation-only and meaningful # for the physics-only optimization. 
anti_degen_controller: Any = None if config.enable_anti_degeneracy and config.enable_gradient_monitoring: try: from heterodyne.optimization.nlsq.anti_degeneracy_controller import ( AntiDegeneracyController, ) anti_degen_controller = AntiDegeneracyController.from_config( config_dict=_anti_degen_dict_from_config(config), n_phi=n_phi, phi_angles=np.deg2rad(np.asarray(phi_angles, dtype=np.float64)), n_physical=n_physics_varying, per_angle_scaling=False, # fixed-constant has no per-angle block ) except Exception as exc: # noqa: BLE001 — never derail the fit logger.warning("Fixed-constant L4 controller construction skipped: %s", exc) anti_degen_controller = None t = model.t q = model.q dt = model.dt c2_data_list = [jnp.asarray(c2_data[i], dtype=jnp.float64) for i in range(n_phi)] weights_list: list[jnp.ndarray | None] = [] for i in range(n_phi): if weights is not None and weights.ndim == 3: weights_list.append(jnp.asarray(weights[i], dtype=jnp.float64)) elif weights is not None: weights_list.append(jnp.asarray(weights, dtype=jnp.float64)) else: weights_list.append(None) c2_data_batch = jnp.stack(c2_data_list, axis=0) weights_batch = jnp.stack( [ (w if w is not None else jnp.ones_like(c2_data_list[i])) for i, w in enumerate(weights_list) ], axis=0, ) contrast_per_angle_j = jnp.asarray(contrast_per_angle, dtype=jnp.float64) offset_per_angle_j = jnp.asarray(offset_per_angle, dtype=jnp.float64) def joint_residual_fn(physics_only: np.ndarray) -> np.ndarray: """Residual closure: β,o are FIXED per-angle constants.""" full = jnp.asarray(physics_only, dtype=jnp.float64) all_res = [] for i, phi in enumerate(phi_angles): res_i = compute_residuals( full, t, q, dt, float(phi), c2_data_batch[i], weights_batch[i], contrast=float(contrast_per_angle_j[i]), offset=float(offset_per_angle_j[i]), ) all_res.append(np.asarray(res_i).reshape(-1)) return np.concatenate(all_res) joint_config = NLSQConfig( method=config.method if config.method != "lm" else "trf", ftol=config.ftol, 
xtol=config.xtol, gtol=config.gtol, max_nfev=(config.max_nfev * n_phi if config.max_nfev is not None else None), loss=config.loss, use_nlsq_library=config.use_nlsq_library, n_params=len(x0), ) joint_result: NLSQResult | None = None if HAS_ADAPTERS: try: joint_adapter = NLSQAdapter(parameter_names=joint_param_names) # pyright: ignore[reportPossiblyUnbound] joint_result = joint_adapter.fit( residual_fn=joint_residual_fn, initial_params=x0, bounds=(lb, ub), config=joint_config, ) if not joint_result.success: raise RuntimeError( f"Fixed-constant joint adapter returned " f"success=False: {joint_result.message}" ) except (ValueError, RuntimeError, TypeError) as adapter_exc: logger.warning( "Fixed-constant NLSQAdapter failed, falling back to wrapper: %s", adapter_exc, ) joint_result = None if joint_result is None and HAS_WRAPPER: joint_wrapper = NLSQWrapper(parameter_names=joint_param_names) # pyright: ignore[reportPossiblyUnbound] joint_result = joint_wrapper.fit( residual_fn=joint_residual_fn, initial_params=x0, bounds=(lb, ub), config=joint_config, ) if joint_result is None: raise ImportError( "No NLSQ backend available for fixed-constant joint multi-angle fit." ) # L4 gradient collapse monitor wiring (Sub-PR C3). monitor_summary: dict[str, Any] = {} if ( anti_degen_controller is not None and getattr(anti_degen_controller, "monitor", None) is not None and config.enable_gradient_monitoring ): # TODO(L4-followup): The current adapter loop does not expose a # per-iteration gradient callback to the controller, so # monitor.get_summary() is observation-passive — it returns empty # history. To activate L4 we need either (a) the NLSQAdapter to # call back into the controller via iteration_callback from # create_nlsq_callbacks(), or (b) post-fit gradient-norm sampling # driven by the joint solve's iteration trace. The metadata key # 'gradient_monitor' is currently set only as a placeholder to # surface that L4 was REQUESTED. 
try: monitor_summary = dict(anti_degen_controller.monitor.get_summary() or {}) except (AttributeError, TypeError) as exc: logger.debug("L4 monitor summary unavailable: %s", exc) monitor_summary = {} if monitor_summary: logger.info("L4 gradient monitor summary: %s", monitor_summary) fitted_physics = np.asarray(joint_result.parameters, dtype=np.float64) full_fitted = param_manager.expand_varying_to_full(fitted_physics) model.set_params(full_fitted) if hasattr(model, "scaling"): # Surface the frozen per-angle β,o on the model for downstream viz. try: model.scaling.contrast[:] = contrast_per_angle model.scaling.offset[:] = offset_per_angle except (ValueError, TypeError): # Length mismatch (e.g. scaling was a scalar holder): replace. model.scaling.contrast = np.asarray(contrast_per_angle, dtype=np.float64) model.scaling.offset = np.asarray(offset_per_angle, dtype=np.float64) wall_time = time.perf_counter() - t_start results: list[NLSQResult] = [] for i, phi in enumerate(phi_angles): # Post-review backfill: populate fitted_correlation, residuals, # reduced_chi_squared per-angle (parity with _fit_joint_averaged_multi_phi). # KEY DIFFERENCE from averaged path: β,o are per-angle frozen constants, # not averaged scalars. 
fitted_c2 = compute_c2_heterodyne( jnp.asarray(full_fitted), t, q, dt, float(phi), contrast=float(contrast_per_angle[i]), offset=float(offset_per_angle[i]), ) residuals_i = np.asarray( compute_residuals( jnp.asarray(full_fitted), t, q, dt, float(phi), c2_data_batch[i], weights_batch[i], contrast=float(contrast_per_angle[i]), offset=float(offset_per_angle[i]), ) ) per_angle_cost_i, per_angle_chi2_i = _compute_per_angle_chi2( residuals_i, np.asarray(c2_data_batch[i]), n_physics_varying ) per_angle_result = NLSQResult( parameters=fitted_physics.copy(), parameter_names=list(varying_names), success=joint_result.success, message=joint_result.message, covariance=joint_result.covariance, residuals=residuals_i, final_cost=per_angle_cost_i, reduced_chi_squared=per_angle_chi2_i, fitted_correlation=np.asarray(fitted_c2), n_iterations=joint_result.n_iterations, n_function_evals=joint_result.n_function_evals, wall_time_seconds=wall_time, metadata={ **dict(joint_result.metadata or {}), "phi_angle": float(phi), "per_angle_mode_actual": "fixed_constant", "contrast_fixed": float(contrast_per_angle[i]), "offset_fixed": float(offset_per_angle[i]), "optimizer": "joint_fixed_constant", **({"gradient_monitor": monitor_summary} if monitor_summary else {}), }, ) results.append(per_angle_result) return results def _fit_joint_cmaes_multi_phi( model: HeterodyneModel, c2_data: np.ndarray, phi_angles: np.ndarray, config: NLSQConfig, weights: np.ndarray | None, ) -> list[NLSQResult]: """Joint multi-angle CMA-ES with NLSQ warm-start and auto-skip. This mirrors homodyne's CMA-ES procedure at the orchestration level: first run the joint NLSQ path, optionally skip global search when the warm-start is already good, otherwise run CMA-ES and keep the lower-cost result. 
""" from heterodyne.config.parameter_registry import SCALING_PARAMS from heterodyne.optimization.nlsq.cmaes_wrapper import CMAESConfig use_fixed_constant = _use_fixed_constant_scaling_mode(config, len(phi_angles)) use_averaged_constant = _use_averaged_constant_scaling_mode(config, len(phi_angles)) # legacy alias retained for downstream CMA-ES setup (residual closure, # bounds construction, result reconstruction) which all share the # averaged-style search-space layout regardless of warmstart path. use_constant = use_fixed_constant or use_averaged_constant fourier = ( None if use_constant else _build_fourier_reparameterizer( phi_angles, config, ) ) if use_fixed_constant: warmstart_label = "fixed-constant (per-angle β,o frozen)" elif use_averaged_constant: warmstart_label = "averaged constant (β̄,ō jointly fit)" else: warmstart_label = "fourier/independent" logger.info("=" * 60) logger.info("CMA-ES GLOBAL OPTIMIZATION") logger.info("=" * 60) logger.info("Analysis mode: %s", config.analysis_mode) logger.info( "Anti-degeneracy scaling mode: %s%s", warmstart_label, f" ({config.per_angle_mode})", ) if use_fixed_constant: warmstart_results = _fit_joint_fixed_constant_multi_phi( model=model, c2_data=c2_data, phi_angles=phi_angles, config=config, weights=weights, ) elif use_averaged_constant: warmstart_results = _fit_joint_averaged_multi_phi( model=model, c2_data=c2_data, phi_angles=phi_angles, config=config, weights=weights, ) else: warmstart_results = _fit_joint_multi_phi( model=model, c2_data=c2_data, phi_angles=phi_angles, config=config, weights=weights, fourier=fourier, ) first = warmstart_results[0] warmstart_cost = ( float(first.final_cost) if first.final_cost is not None else float("inf") ) warmstart_reduced_chi2 = ( float(first.reduced_chi_squared) if first.reduced_chi_squared is not None else float("inf") ) logger.info( "[CMA-ES] NLSQ warm-start succeeded: cost=%.4e, reduced chi2=%.4f", warmstart_cost, warmstart_reduced_chi2, ) auto_skip = bool(getattr(config, 
"cmaes_warmstart_auto_skip", True)) skip_threshold = float(getattr(config, "cmaes_warmstart_skip_threshold", 5.0)) if auto_skip and warmstart_reduced_chi2 < skip_threshold: logger.info( "[CMA-ES] Auto-skip: NLSQ warm-start reduced chi2=%.4f < threshold=%.1f. " "Skipping CMA-ES global search.", warmstart_reduced_chi2, skip_threshold, ) for result in warmstart_results: result.metadata["optimizer"] = "joint_cmaes_warmstart_auto_skip" result.metadata["cmaes_skipped"] = True result.metadata["warmstart_reduced_chi2"] = warmstart_reduced_chi2 return warmstart_results param_manager = model.param_manager varying_names = list(param_manager.varying_names) n_physics_varying = param_manager.n_varying n_phi = len(phi_angles) physics_lower, physics_upper = param_manager.get_bounds() if use_constant: contrast_bounds = ( SCALING_PARAMS["contrast"].min_bound, SCALING_PARAMS["contrast"].max_bound, ) offset_bounds = ( SCALING_PARAMS["offset"].min_bound, SCALING_PARAMS["offset"].max_bound, ) scaling_lower = np.array( [contrast_bounds[0], offset_bounds[0]], dtype=np.float64, ) scaling_upper = np.array( [contrast_bounds[1], offset_bounds[1]], dtype=np.float64, ) scaling_initial = np.array( [ float(first.metadata.get("contrast", 0.3)), float(first.metadata.get("offset", 1.0)), ], dtype=np.float64, ) scaling_names = ["contrast", "offset"] else: assert fourier is not None contrast_initial = np.array( [ float(result.metadata.get("contrast", 0.3)) for result in warmstart_results ], dtype=np.float64, ) offset_initial = np.array( [float(result.metadata.get("offset", 1.0)) for result in warmstart_results], dtype=np.float64, ) scaling_initial = fourier.per_angle_to_fourier( contrast_initial, offset_initial, ) scaling_lower, scaling_upper = fourier.get_bounds() scaling_names = fourier.get_coefficient_labels() bounds = ( np.concatenate([physics_lower, scaling_lower]), np.concatenate([physics_upper, scaling_upper]), ) initial_params = np.concatenate( [np.asarray(first.parameters, dtype=np.float64), 
scaling_initial] ) parameter_names = [*varying_names, *scaling_names] c2_data_batch = jnp.asarray(c2_data, dtype=jnp.float64) weights_batch = ( jnp.asarray(weights, dtype=jnp.float64) if weights is not None else jnp.ones_like(c2_data_batch) ) if weights_batch.ndim == 2: weights_batch = jnp.broadcast_to(weights_batch, c2_data_batch.shape) t = model.t q = model.q dt = model.dt phi_angles_jax = jnp.asarray(phi_angles, dtype=jnp.float64) fixed_values_jax = jnp.asarray(param_manager.get_full_values(), dtype=jnp.float64) varying_indices_jax = jnp.array(param_manager.varying_indices, dtype=jnp.int32) # CMA-ES phase searches in averaged-style space (14 physics + 2 averaged β̄,ō); # fixed-constant warmstart hands off seeds in this space. # NOTE: must return a JAX array. NLSQ's masked_residual_func JIT-traces this # closure; np.asarray() on a traced result raises TracerArrayConversionError. def residual_fn(x: np.ndarray) -> Any: # type: ignore[return-value] physics_varying = x[:n_physics_varying] full_jax = fixed_values_jax.at[varying_indices_jax].set( jnp.asarray(physics_varying, dtype=jnp.float64) ) scaling_params = x[n_physics_varying:] if use_constant: contrast = scaling_params[0] offset = scaling_params[1] contrasts_jax = jnp.full((n_phi,), contrast, dtype=jnp.float64) offsets_jax = jnp.full((n_phi,), offset, dtype=jnp.float64) else: assert fourier is not None contrast_arr, offset_arr = fourier.fourier_to_per_angle(scaling_params) contrasts_jax = jnp.asarray(contrast_arr, dtype=jnp.float64) offsets_jax = jnp.asarray(offset_arr, dtype=jnp.float64) return compute_multi_angle_residuals( full_jax, t, q, dt, phi_angles_jax, c2_data_batch, weights_batch, contrasts_jax, offsets_jax, ) def objective_fn(x: np.ndarray) -> float: residuals = residual_fn(x) return float(0.5 * np.sum(residuals**2)) logger.info("[CMA-ES] Phase 2: Running CMA-ES global optimization...") n_time = int(c2_data_batch.shape[-1]) n_off_diagonal_data = int(n_phi * (n_time - 1) * (n_time - 2)) 
restart_strategy = getattr(config, "cmaes_restart_strategy", "bipop") max_restarts = getattr(config, "cmaes_max_restarts", 9) # Warmstart is always active in this path: BIPOP large-population restarts # are incoherent with a tight initial sigma derived from the NLSQ solution. if restart_strategy == "bipop": restart_strategy = "none" max_restarts = 0 logger.debug( "[CMA-ES] Warm-start active: overriding restart_strategy='bipop' -> 'none' " "(BIPOP large-population restarts are incoherent with small sigma_warmstart)" ) cmaes_result = fit_with_cmaes( # pyright: ignore[reportPossiblyUnbound] objective_fn=objective_fn, initial_params=initial_params, bounds=bounds, parameter_names=parameter_names, config=CMAESConfig( sigma0=config.cmaes_sigma0, popsize=config.cmaes_population_size, maxiter=config.cmaes_max_iterations, tolx=config.cmaes_tolx, tolfun=config.cmaes_tolfun, diagonal_filtering=getattr(config, "cmaes_diagonal_filtering", "none"), restart_strategy=restart_strategy, max_restarts=max_restarts, ), residual_fn=residual_fn, n_data=n_off_diagonal_data, anti_degeneracy=getattr(config, "cmaes_anti_degeneracy", False), ) cmaes_cost = ( float(cmaes_result.final_cost) if cmaes_result.final_cost is not None else float("inf") ) if warmstart_cost <= cmaes_cost: logger.info( "[CMA-ES] NLSQ warm-start result is better: NLSQ cost=%.4e < CMA-ES cost=%.4e. 
" "Using NLSQ solution.", warmstart_cost, cmaes_cost, ) for result in warmstart_results: result.metadata["optimizer"] = "joint_cmaes_warmstart" result.metadata["cmaes_cost"] = cmaes_cost result.metadata["nlsq_warmstart_cost"] = warmstart_cost return warmstart_results logger.info( "[CMA-ES] CMA-ES result is better: CMA-ES cost=%.4e <= NLSQ cost=%.4e", cmaes_cost, warmstart_cost, ) fitted = np.asarray(cmaes_result.parameters, dtype=np.float64) fitted_physics = fitted[:n_physics_varying] fitted_scaling = fitted[n_physics_varying:] if use_constant: fitted_contrast = np.full(n_phi, float(fitted_scaling[0]), dtype=np.float64) fitted_offset = np.full(n_phi, float(fitted_scaling[1]), dtype=np.float64) else: assert fourier is not None fitted_contrast, fitted_offset = fourier.fourier_to_per_angle(fitted_scaling) full_fitted = param_manager.expand_varying_to_full(fitted_physics) model.set_params(full_fitted) if hasattr(model, "scaling"): model.scaling.contrast[:] = fitted_contrast model.scaling.offset[:] = fitted_offset results: list[NLSQResult] = [] for i, phi in enumerate(phi_angles): contrast_i = float(fitted_contrast[i]) offset_i = float(fitted_offset[i]) fitted_c2 = compute_c2_heterodyne( jnp.asarray(full_fitted), t, q, dt, float(phi), contrast=contrast_i, offset=offset_i, ) residuals = np.asarray( compute_residuals( jnp.asarray(full_fitted), t, q, dt, float(phi), c2_data_batch[i], weights_batch[i], contrast=contrast_i, offset=offset_i, ) ) metadata = { "phi_angle": float(phi), "contrast": contrast_i, "offset": offset_i, "optimizer": "joint_cmaes", "n_angles_joint": n_phi, "cmaes_cost": cmaes_cost, "nlsq_warmstart_cost": warmstart_cost, } if use_constant: metadata["anti_degeneracy_mode"] = "constant_averaged" else: assert fourier is not None metadata.update( { "anti_degeneracy_mode": fourier.config.mode, "fourier_mode": fourier.config.mode, "fourier_order": fourier.order, "fourier_coeffs": fitted_scaling.tolist(), "fourier_n_coeffs": fourier.n_coeffs, 
"fourier_reduction": fourier.get_diagnostics()["reduction_ratio"], } ) results.append( NLSQResult( parameters=fitted_physics.copy(), parameter_names=varying_names, uncertainties=( cmaes_result.uncertainties[:n_physics_varying].copy() if cmaes_result.uncertainties is not None else None ), covariance=( cmaes_result.covariance[ :n_physics_varying, :n_physics_varying ].copy() if cmaes_result.covariance is not None else None ), residuals=residuals, final_cost=cmaes_result.final_cost, reduced_chi_squared=cmaes_result.reduced_chi_squared, success=bool(cmaes_result.success), message=str(cmaes_result.message), n_iterations=cmaes_result.n_iterations, n_function_evals=cmaes_result.n_function_evals, convergence_reason=cmaes_result.convergence_reason, fitted_correlation=np.asarray(fitted_c2), wall_time_seconds=cmaes_result.wall_time_seconds, metadata=metadata, ) ) return results def _use_fixed_constant_scaling_mode(config: NLSQConfig, n_phi: int) -> bool: """True when per-angle β,o must be FROZEN (homodyne `constant` parity).""" del n_phi # explicit mode is threshold-independent return config.per_angle_mode == "constant" def _use_averaged_constant_scaling_mode(config: NLSQConfig, n_phi: int) -> bool: """True when β,o are averaged across angles and OPTIMIZED (homodyne `auto`).""" constant_threshold = max(int(getattr(config, "constant_scaling_threshold", 3)), 1) return config.per_angle_mode == "auto" and n_phi >= constant_threshold def _use_constant_scaling_mode(config: NLSQConfig, n_phi: int) -> bool: """Legacy union predicate; retained for backward compatibility. True if EITHER the fixed-constant or averaged-constant path applies. New code should call the specific predicate; this wrapper exists so legacy importers do not break. 
""" return _use_fixed_constant_scaling_mode( config, n_phi ) or _use_averaged_constant_scaling_mode(config, n_phi) def _run_input_validation( *, model: HeterodyneModel, c2_data: np.ndarray | jnp.ndarray, ) -> None: """Run :class:`InputValidator` against the fit inputs (non-strict). Bounds, finiteness, and initial-param-within-bounds checks. Any failures are logged at WARNING; the fit is not blocked. Strict enforcement is opt-in via the validator constructor and is not used here so the validation never breaks an otherwise-healthy fit. Failures inside the validator itself are themselves caught — input validation is an observer, not part of the hot path. """ try: from heterodyne.optimization.nlsq.validation import InputValidator initial = np.asarray(model.param_manager.get_initial_values(), dtype=np.float64) lower, upper = model.param_manager.get_bounds() bounds = ( np.asarray(lower, dtype=np.float64), np.asarray(upper, dtype=np.float64), ) data = np.asarray(c2_data, dtype=np.float64) report = InputValidator(strict_mode=False).validate( data=data, initial_params=initial, bounds=bounds, ) if not report.is_valid: logger.warning( "Input validation flagged %d issue(s) (continuing fit)", len(report.issues), ) except Exception as exc: # noqa: BLE001 — observer must never block logger.debug("Input validation skipped: %s", exc) def _run_result_validation( *, result: NLSQResult, model: HeterodyneModel, ) -> None: """Run :class:`ResultValidator` against the optimization result (non-strict).""" try: from heterodyne.optimization.nlsq.validation import ResultValidator if result.parameters is None: return params = np.asarray(result.parameters, dtype=np.float64) lower, upper = model.param_manager.get_bounds() bounds = ( np.asarray(lower, dtype=np.float64), np.asarray(upper, dtype=np.float64), ) cov = ( np.asarray(result.covariance, dtype=np.float64) if result.covariance is not None else None ) ResultValidator(strict_mode=False).validate_all( params=params, covariance=cov, 
bounds=bounds, chi_squared=result.final_cost, ) except Exception as exc: # noqa: BLE001 — observer must never block logger.debug("Result validation skipped: %s", exc) def _log_anti_degeneracy_diagnostics( *, config: NLSQConfig | None, phi_angles: np.ndarray, ) -> None: """Construct an ``AntiDegeneracyController`` and log its diagnostics. The controller's mode-resolution (auto → fourier / auto_averaged / individual / fixed_constant) and group-variance indices are computed here and emitted at DEBUG level. The class is the ported homodyne 4-layer defense; heterodyne's actual joint-fit path lives in ``_fit_joint_multi_phi``, but the controller is exercised here as an active observer so its parity behaviour is verified on every multi-angle run. Failures (import errors, missing config fields, malformed phi arrays) are caught and logged at WARNING — the diagnostics path must never derail an otherwise-healthy fit. """ if config is None or len(phi_angles) <= 1: return try: from heterodyne.optimization.nlsq.anti_degeneracy_controller import ( AntiDegeneracyController, ) ad_config_dict: dict[str, Any] = { "enable": True, "per_angle_mode": getattr(config, "per_angle_mode", "auto"), "constant_scaling_threshold": int( getattr(config, "constant_scaling_threshold", 3) ), "fourier_auto_threshold": int( getattr(config, "fourier_auto_threshold", 10) ), "fourier_order": int(getattr(config, "fourier_order", 1)), } controller = AntiDegeneracyController.from_config( config_dict=ad_config_dict, n_phi=len(phi_angles), phi_angles=np.deg2rad(phi_angles.astype(np.float64)), per_angle_scaling=True, ) diag = controller.get_diagnostics() logger.debug( "AntiDegeneracyController diagnostics: mode=%s, n_phi=%d, " "per_angle_mode_actual=%s, group_indices=%s", diag.get("per_angle_mode"), diag.get("n_phi"), diag.get("per_angle_mode_actual"), controller.get_group_variance_indices(), ) except Exception as exc: # noqa: BLE001 — diagnostics must not block fits logger.warning("Anti-degeneracy diagnostics 
skipped: %s", exc) def _build_fourier_reparameterizer(phi_angles: np.ndarray, config: NLSQConfig) -> Any: """Build the Fourier/independent reparameterizer for fallback paths.""" from heterodyne.optimization.nlsq.fourier_reparam import ( FourierReparamConfig, FourierReparameterizer, ) return FourierReparameterizer( np.deg2rad(phi_angles.astype(np.float64)), FourierReparamConfig( mode=config.per_angle_mode, fourier_order=config.fourier_order, auto_threshold=config.fourier_auto_threshold, ), ) def _fit_joint_multi_phi( model: HeterodyneModel, c2_data: np.ndarray, phi_angles: np.ndarray, config: NLSQConfig, weights: np.ndarray | None, fourier: Any, ) -> list[NLSQResult]: """Joint multi-angle fit with Fourier-parameterized scaling. The optimizer parameter vector is: [physics_varying_params | fourier_contrast_coeffs | fourier_offset_coeffs] The residual function evaluates all angles, using the Fourier basis to convert coefficients → per-angle contrast/offset at each evaluation. This is the heterodyne equivalent of homodyne's AntiDegeneracyController joint-fit path. 
""" t_start = time.perf_counter() param_manager = model.param_manager varying_names = param_manager.varying_names n_physics_varying = param_manager.n_varying n_phi = len(phi_angles) # Physics parameter initial values and bounds physics_initial = param_manager.get_initial_values() physics_lower, physics_upper = param_manager.get_bounds() physics_initial = np.clip(physics_initial, physics_lower, physics_upper) # Fourier coefficient initial values and bounds scaling = model.scaling contrast_init = float(scaling.contrast[0]) if len(scaling.contrast) > 0 else 0.5 offset_init = float(scaling.offset[0]) if len(scaling.offset) > 0 else 1.0 fourier_initial = fourier.get_initial_coefficients(contrast_init, offset_init) fourier_lower, fourier_upper = fourier.get_bounds() # Combined parameter vector x0 = np.concatenate([physics_initial, fourier_initial]) lb = np.concatenate([physics_lower, fourier_lower]) ub = np.concatenate([physics_upper, fourier_upper]) logger.info( "Joint multi-angle fit: %d physics + %d Fourier = %d total params, %d angles", n_physics_varying, fourier.n_coeffs, len(x0), n_phi, ) # Anti-degeneracy controller construction (Sub-PR C1 + C2). # Built whenever EITHER L2 hierarchical OR L3 regularization is requested. 
anti_degen_controller = _build_anti_degen_controller( config=config, n_phi=n_phi, phi_angles=np.asarray(phi_angles, dtype=np.float64), n_physical=n_physics_varying, ) hierarchical_marker = _hierarchical_marker_from_controller(anti_degen_controller) # Pre-convert data to JAX arrays (outside closure — constants) t, q, dt = model.t, model.q, model.dt c2_data_list = [jnp.asarray(c2_data[i], dtype=jnp.float64) for i in range(n_phi)] weights_list: list[jnp.ndarray | None] = [] for i in range(n_phi): if weights is not None and weights.ndim == 3: weights_list.append(jnp.asarray(weights[i], dtype=jnp.float64)) elif weights is not None: weights_list.append(jnp.asarray(weights, dtype=jnp.float64)) else: weights_list.append(None) # Pre-stack batched arrays for compute_multi_angle_residuals. # weights_list entries may be None (unweighted) — materialise ones_like # so the stacked weights_batch is always a concrete (n_phi, N, N) array. c2_data_batch = jnp.stack(c2_data_list, axis=0) # (n_phi, N, N) weights_batch = jnp.stack( [ (w if w is not None else jnp.ones_like(c2_data_list[i])) for i, w in enumerate(weights_list) ], axis=0, ) # (n_phi, N, N) phi_angles_jax = jnp.asarray(phi_angles, dtype=jnp.float64) # (n_phi,) fixed_values_jax = jnp.asarray(param_manager.get_full_values(), dtype=jnp.float64) varying_indices_jax = jnp.array(param_manager.varying_indices, dtype=jnp.int32) def joint_residual_fn(x: np.ndarray) -> np.ndarray: """Compute concatenated residuals across all angles via vmap. Routes through ``compute_multi_angle_residuals`` (jit + vmap) to replace the previous n_phi serial kernel dispatches with a single batched XLA call. Fourier reparameterization is preserved: the combined parameter vector is split into physics and Fourier parts, and ``fourier.fourier_to_per_angle`` converts coefficients to per-angle contrast/offset arrays before the batched residual call. 
""" # Split combined vector physics_varying = x[:n_physics_varying] fourier_coeffs = x[n_physics_varying:] # Reconstruct full physics parameter array (immutable JAX scatter) varying_jax = jnp.asarray(physics_varying, dtype=jnp.float64) full_jax = fixed_values_jax.at[varying_indices_jax].set(varying_jax) # Convert Fourier coefficients → per-angle contrast/offset contrast_arr, offset_arr = fourier.fourier_to_per_angle(fourier_coeffs) contrasts_jax = jnp.asarray(contrast_arr, dtype=jnp.float64) # (n_phi,) offsets_jax = jnp.asarray(offset_arr, dtype=jnp.float64) # (n_phi,) # Single batched vmap call — eliminates n_phi serial dispatches return np.asarray( compute_multi_angle_residuals( full_jax, t, q, dt, phi_angles_jax, c2_data_batch, weights_batch, contrasts_jax, offsets_jax, ) ) # L3 adaptive regularization wiring (Sub-PR C2). # When regularization_mode != "none", append a single Tikhonov penalty # row to the residual vector so the active Fourier/individual fit # penalises per-angle scaling variance via the controller's # loss-augmentation callback. 
if ( anti_degen_controller is not None and getattr(anti_degen_controller, "regularizer", None) is not None and config.regularization_mode != "none" ): callbacks = anti_degen_controller.create_nlsq_callbacks() loss_aug = callbacks.get("loss_augmentation") if loss_aug is not None: _inner_residual_fn = joint_residual_fn def joint_residual_fn_with_penalty(x: np.ndarray) -> np.ndarray: """Residual + appended penalty row from controller's loss_aug.""" base = _inner_residual_fn(x) base_np = np.asarray(base) penalty_value = float(loss_aug(np.asarray(x), base_np)) penalty_row = float(np.sqrt(max(2.0 * penalty_value, 0.0))) return np.concatenate([base_np, [penalty_row]]) joint_residual_fn = joint_residual_fn_with_penalty logger.info( "L3 adaptive regularization active (mode=%s, lambda=%.4e)", config.regularization_mode, config.group_variance_lambda, ) # Run optimization via NLSQAdapter (primary) with NLSQWrapper fallback joint_config = NLSQConfig( method=config.method if config.method != "lm" else "trf", ftol=config.ftol, xtol=config.xtol, gtol=config.gtol, max_nfev=(config.max_nfev * n_phi if config.max_nfev is not None else None), ) joint_result: NLSQResult | None = None joint_param_names = list(varying_names) + [ f"fourier_{i}" for i in range(len(fourier_initial)) ] if HAS_ADAPTERS: try: joint_adapter = NLSQAdapter(parameter_names=joint_param_names) # pyright: ignore[reportPossiblyUnbound] joint_result = joint_adapter.fit( residual_fn=joint_residual_fn, initial_params=x0, bounds=(lb, ub), config=joint_config, ) if not joint_result.success: raise RuntimeError( f"Joint adapter returned success=False: {joint_result.message}" ) except (ValueError, RuntimeError, TypeError) as adapter_exc: logger.warning( "Joint NLSQAdapter failed, falling back to NLSQWrapper: %s", adapter_exc ) joint_result = None if joint_result is None and HAS_WRAPPER: joint_wrapper = NLSQWrapper(parameter_names=joint_param_names) # pyright: ignore[reportPossiblyUnbound] joint_result = joint_wrapper.fit( 
residual_fn=joint_residual_fn, initial_params=x0, bounds=(lb, ub), config=joint_config, ) if joint_result is None: raise ImportError( "No NLSQ backend available for joint multi-angle fit. " "Ensure heterodyne.optimization.nlsq.adapter is importable." ) # L4 gradient collapse monitor wiring (Sub-PR C3). monitor_summary: dict[str, Any] = {} if ( anti_degen_controller is not None and getattr(anti_degen_controller, "monitor", None) is not None and config.enable_gradient_monitoring ): # TODO(L4-followup): The current adapter loop does not expose a # per-iteration gradient callback to the controller, so # monitor.get_summary() is observation-passive — it returns empty # history. To activate L4 we need either (a) the NLSQAdapter to # call back into the controller via iteration_callback from # create_nlsq_callbacks(), or (b) post-fit gradient-norm sampling # driven by the joint solve's iteration trace. The metadata key # 'gradient_monitor' is currently set only as a placeholder to # surface that L4 was REQUESTED. 
try: monitor_summary = dict(anti_degen_controller.monitor.get_summary() or {}) except (AttributeError, TypeError) as exc: logger.debug("L4 monitor summary unavailable: %s", exc) monitor_summary = {} if monitor_summary: logger.info("L4 gradient monitor summary: %s", monitor_summary) # Extract results fitted_params_full = joint_result.parameters fitted_physics = fitted_params_full[:n_physics_varying] fitted_fourier = fitted_params_full[n_physics_varying:] fitted_contrast, fitted_offset = fourier.fourier_to_per_angle(fitted_fourier) # Update model with fitted physics parameters full_fitted = param_manager.expand_varying_to_full(fitted_physics) model.set_params(full_fitted) # Update model scaling if len(scaling.contrast) == n_phi: scaling.contrast[:] = fitted_contrast scaling.offset[:] = fitted_offset wall_time = time.perf_counter() - t_start # Build per-angle NLSQResult objects results: list[NLSQResult] = [] for i in range(n_phi): # Compute fitted correlation for this angle fitted_c2 = compute_c2_heterodyne( jnp.asarray(full_fitted), t, q, dt, float(phi_angles[i]), contrast=float(fitted_contrast[i]), offset=float(fitted_offset[i]), ) _residuals_i = np.asarray( compute_residuals( jnp.asarray(full_fitted), t, q, dt, float(phi_angles[i]), c2_data_list[i], weights_list[i], contrast=float(fitted_contrast[i]), offset=float(fitted_offset[i]), ) ) _per_cost_i, _per_chi2_i = _compute_per_angle_chi2( _residuals_i, np.asarray(c2_data_list[i]), n_physics_varying ) result = NLSQResult( parameters=fitted_physics.copy(), parameter_names=list(varying_names), residuals=_residuals_i, final_cost=_per_cost_i, reduced_chi_squared=_per_chi2_i, success=bool(joint_result.success), message=str(joint_result.message), n_function_evals=int(joint_result.n_function_evals or 0), fitted_correlation=np.asarray(fitted_c2), metadata={ "phi_angle": float(phi_angles[i]), "contrast": float(fitted_contrast[i]), "offset": float(fitted_offset[i]), "optimizer": "joint_fourier", "fourier_mode": 
fourier.config.mode, "fourier_order": fourier.order, "fourier_coeffs": fitted_fourier.tolist(), "fourier_n_coeffs": fourier.n_coeffs, "fourier_reduction": fourier.get_diagnostics()["reduction_ratio"], "n_angles_joint": n_phi, "wall_time_total": wall_time, **( {"hierarchical_config": hierarchical_marker} if hierarchical_marker is not None else {} ), **({"gradient_monitor": monitor_summary} if monitor_summary else {}), }, ) results.append(result) logger.info( "Joint multi-angle fit complete: success=%s, cost=%.6f, " "n_evals=%d, wall_time=%.2fs, %d angles", joint_result.success, joint_result.final_cost or 0.0, joint_result.n_function_evals or 0, wall_time, n_phi, ) return results # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _try_global_optimization( model: HeterodyneModel, c2_data: np.ndarray | jnp.ndarray, phi_angle: float, config: NLSQConfig, weights: np.ndarray | jnp.ndarray | None, use_nlsq_library: bool, ) -> NLSQResult | None: """Attempt CMA-ES or multi-start if configured. Returns the result if a global method was selected, or ``None`` to fall through to local optimization. """ # CMA-ES has highest priority if getattr(config, "enable_cmaes", False): if HAS_CMAES: logger.info("CMA-ES enabled, delegating to fit_nlsq_cmaes") return _fit_cmaes(model, c2_data, phi_angle, config, weights) logger.warning( "[CMA-ES] Enabled in config but not available (cma not installed). " "Install with: uv add cma. " "Falling back to multi-start or local optimization." ) # Multi-start is second priority if getattr(config, "multistart", False): if HAS_MULTISTART: logger.info("Multi-start enabled, delegating to fit_nlsq_multistart") return _fit_multistart( model, c2_data, phi_angle, config, weights, use_nlsq_library, ) logger.warning( "Multi-start enabled in config but multistart module not available. " "Falling back to local optimization." 
) return None def _fit_cmaes( model: HeterodyneModel, c2_data: np.ndarray | jnp.ndarray, phi_angle: float, config: NLSQConfig, weights: np.ndarray | jnp.ndarray | None, ) -> NLSQResult: """Run CMA-ES global optimization with NLSQ warm-start and two-phase comparison. Implements fixes #1, #5, #6, #7 from homodyne parity: - **Phase 1 (Fix #1)**: Run local NLSQ refinement to get a warm-start point. - **Phase 2**: Run CMA-ES using the NLSQ result as initial guess. - **Phase 3 (Fix #7)**: Compare NLSQ vs CMA-ES by reduced chi-squared, keep the better result. - **Fix #5**: Classify result quality as good/marginal/poor. - **Fix #6**: Optionally apply anti-degeneracy penalty weights. """ from heterodyne.optimization.nlsq.cmaes_wrapper import CMAESConfig from heterodyne.optimization.nlsq.validation.fit_quality import classify_fit_quality param_manager = model.param_manager initial_varying = param_manager.get_initial_values() lower_bounds, upper_bounds = param_manager.get_bounds() initial_varying = np.clip(initial_varying, lower_bounds, upper_bounds) c2_jax = jnp.asarray(c2_data, dtype=jnp.float64) weights_jax = ( jnp.asarray(weights, dtype=jnp.float64) if weights is not None else None ) t, q, dt = model.t, model.q, model.dt n_data = c2_jax.size contrast_val, offset_val = model.scaling.get_for_angle(0) def objective_fn(varying_params: np.ndarray) -> float: full_params = np.array(param_manager.get_full_values()) for i, idx in enumerate(param_manager.varying_indices): full_params[idx] = varying_params[i] residuals = compute_residuals( jnp.asarray(full_params), t, q, dt, phi_angle, c2_jax, weights_jax, contrast_val, offset_val, ) return float(0.5 * jnp.sum(residuals**2)) residual_fn = _make_numpy_residual_fn( model, c2_data, phi_angle, weights, contrast_val, offset_val ) # ------------------------------------------------------------------ # Phase 1 (Fix #1): NLSQ warm-start # ------------------------------------------------------------------ nlsq_result: NLSQResult | None = None 
cmaes_x0 = initial_varying try: logger.info("CMA-ES Phase 1: NLSQ warm-start refinement") nlsq_result = _fit_local( model, c2_data, phi_angle, config, weights, use_nlsq_library=config.use_nlsq_library, ) if nlsq_result.success: cmaes_x0 = nlsq_result.parameters.copy() logger.info( "NLSQ warm-start succeeded: cost=%.6e, chi2_red=%.4f", nlsq_result.final_cost or float("inf"), nlsq_result.reduced_chi_squared or float("inf"), ) else: logger.warning( "NLSQ warm-start failed (%s), using raw initial params for CMA-ES", nlsq_result.message, ) except (ValueError, RuntimeError, ImportError) as e: logger.warning( "NLSQ warm-start raised %s: %s — proceeding with raw p0", type(e).__name__, e, ) # Ensure model parameters are reset for CMA-ES (NLSQ may have modified them) model.set_params(param_manager.expand_varying_to_full(initial_varying)) # ------------------------------------------------------------------ # Phase 2: CMA-ES global optimization # ------------------------------------------------------------------ logger.info("CMA-ES Phase 2: global search (warm-started)") cmaes_config = CMAESConfig( sigma0=config.cmaes_sigma0, popsize=config.cmaes_population_size, maxiter=config.cmaes_max_iterations, tolx=config.cmaes_tolx, tolfun=config.cmaes_tolfun, diagonal_filtering=getattr(config, "cmaes_diagonal_filtering", "none"), ) cmaes_result = fit_with_cmaes( # pyright: ignore[reportPossiblyUnbound] objective_fn=objective_fn, initial_params=cmaes_x0, bounds=(lower_bounds, upper_bounds), parameter_names=param_manager.varying_names, config=cmaes_config, residual_fn=residual_fn, n_data=n_data, anti_degeneracy=getattr(config, "cmaes_anti_degeneracy", False), ) # ------------------------------------------------------------------ # Phase 3 (Fix #7): Compare NLSQ vs CMA-ES, keep the better result # ------------------------------------------------------------------ nlsq_cost = ( float(nlsq_result.final_cost) if (nlsq_result and nlsq_result.success and nlsq_result.final_cost is not None) 
else float("inf") ) cmaes_cost = ( float(cmaes_result.final_cost) if (cmaes_result.success and cmaes_result.final_cost is not None) else float("inf") ) if nlsq_cost <= cmaes_cost and nlsq_result is not None and nlsq_result.success: result = nlsq_result winner = "nlsq" logger.info( "Phase 3: NLSQ wins (cost=%.6e vs CMA-ES=%.6e)", nlsq_cost, cmaes_cost, ) else: result = cmaes_result winner = "cmaes" logger.info( "Phase 3: CMA-ES wins (cost=%.6e vs NLSQ=%.6e)", cmaes_cost, nlsq_cost, ) # ------------------------------------------------------------------ # Post-fit: update model, classify quality (Fix #5) # ------------------------------------------------------------------ if result.success: full_fitted = param_manager.expand_varying_to_full(result.parameters) fitted_c2 = compute_c2_heterodyne( jnp.asarray(full_fitted), t, q, dt, phi_angle, contrast_val, offset_val ) result.fitted_correlation = np.asarray(fitted_c2) model.set_params(full_fitted) # Apply same chi2 correction as _fit_local (DOF + σ² normalization) if result.final_cost is not None: n_matrix = c2_jax.shape[0] n_valid = (n_matrix - 1) * (n_matrix - 2) n_dof_valid = max(n_valid - len(param_manager.varying_names), 1) c2_np = np.asarray(c2_jax) row_idx = np.arange(n_matrix) lag_mat = np.abs(row_idx[:, None] - row_idx[None, :]) far_vals = c2_np[lag_mat >= n_matrix // 2] sigma2_noise = float(np.var(far_vals)) if far_vals.size > 1 else 0.0 if sigma2_noise > 1e-12: ssr = 2.0 * result.final_cost result.reduced_chi_squared = ssr / (sigma2_noise * n_dof_valid) quality_flag = classify_fit_quality(result.reduced_chi_squared) result.metadata["optimizer"] = "cmaes" result.metadata["cmaes_winner"] = winner result.metadata["cmaes_cost"] = cmaes_cost result.metadata["nlsq_warmstart_cost"] = nlsq_cost result.metadata["quality_flag"] = quality_flag _log_result(result) return result def _fit_multistart( model: HeterodyneModel, c2_data: np.ndarray | jnp.ndarray, phi_angle: float, config: NLSQConfig, weights: np.ndarray | 
jnp.ndarray | None, use_nlsq_library: bool, ) -> NLSQResult: """Run multi-start optimization.""" param_manager = model.param_manager varying_names = param_manager.varying_names initial_varying = param_manager.get_initial_values() lower_bounds, upper_bounds = param_manager.get_bounds() initial_varying = np.clip(initial_varying, lower_bounds, upper_bounds) contrast_val, offset_val = model.scaling.get_for_angle(0) # Build residual function residual_fn = _make_numpy_residual_fn( model, c2_data, phi_angle, weights, contrast_val, offset_val ) # Select adapter adapter = _select_adapter(varying_names, use_nlsq_library) # Build multistart config ms_config = MultiStartConfig( # pyright: ignore[reportPossiblyUnbound] n_starts=getattr(config, "multistart_n", 10), seed=getattr(config, "multistart_seed", None), ) optimizer = MultiStartOptimizer(adapter=adapter, config=ms_config) # pyright: ignore[reportPossiblyUnbound] multi_result = optimizer.fit( residual_fn=residual_fn, initial_params=initial_varying, bounds=(lower_bounds, upper_bounds), config=config, ) result = multi_result.to_nlsq_result() if result.success: full_fitted = param_manager.expand_varying_to_full(result.parameters) fitted_c2 = compute_c2_heterodyne( jnp.asarray(full_fitted), model.t, model.q, model.dt, phi_angle, contrast_val, offset_val, ) result.fitted_correlation = np.asarray(fitted_c2) model.set_params(full_fitted) result.metadata["optimizer"] = "multistart" _log_result(result) return result def _fit_local( model: HeterodyneModel, c2_data: np.ndarray | jnp.ndarray, phi_angle: float, config: NLSQConfig, weights: np.ndarray | jnp.ndarray | None, use_nlsq_library: bool, ) -> NLSQResult: """Run local (single-start) optimization with adapter/wrapper fallback. Tries adapter first; on failure falls back to wrapper with progressive recovery. 
""" t_start = time.perf_counter() param_manager = model.param_manager varying_names = param_manager.varying_names n_varying = param_manager.n_varying logger.info("Fitting %d parameters: %s", n_varying, varying_names) # Memory-aware strategy selection if HAS_MEMORY: n_data_est = np.asarray(c2_data).size decision = select_nlsq_strategy(n_data_est, n_varying) # pyright: ignore[reportPossiblyUnbound] if decision.strategy in (NLSQStrategy.LARGE, NLSQStrategy.STREAMING): # pyright: ignore[reportPossiblyUnbound] logger.warning( "Estimated peak memory (%.2f GB) exceeds threshold (%.2f GB). " "Fit may fail with OOM.", decision.peak_memory_gb, decision.threshold_gb, ) # Get initial values and bounds initial_varying = param_manager.get_initial_values() lower_bounds, upper_bounds = param_manager.get_bounds() initial_varying = np.clip(initial_varying, lower_bounds, upper_bounds) # Convert data to JAX arrays c2_jax = jnp.asarray(c2_data, dtype=jnp.float64) weights_jax = ( jnp.asarray(weights, dtype=jnp.float64) if weights is not None else None ) if weights_jax is not None and weights_jax.shape != c2_jax.shape: raise ValueError( f"Weights shape {weights_jax.shape} does not match data shape {c2_jax.shape}" ) # Capture constants fixed_values = jnp.asarray(param_manager.get_full_values(), dtype=jnp.float64) varying_indices = jnp.array(param_manager.varying_indices) n_data = c2_jax.size t, q, dt = model.t, model.q, model.dt # Per-angle scaling — fixed during local optimization (constant mode parity) contrast_val, offset_val = model.scaling.get_for_angle(0) # Build residual functions def jax_residual_fn(x: jnp.ndarray, *varying_params: float) -> jnp.ndarray: """Pure JAX residual function for nlsq tracing.""" varying_array = jnp.array(varying_params, dtype=jnp.float64) full_params = fixed_values.at[varying_indices].set(varying_array) return compute_residuals( full_params, t, q, dt, phi_angle, c2_jax, weights_jax, contrast_val, offset_val, ) numpy_residual_fn = _make_numpy_residual_fn( 
model, c2_data, phi_angle, weights, contrast_val, offset_val ) # ------------------------------------------------------------------ # Adapter → wrapper fallback chain # ------------------------------------------------------------------ adapter_error: Exception | None = None fallback_occurred = False result: NLSQResult | None = None if use_nlsq_library and HAS_ADAPTERS: try: adapter = NLSQAdapter(parameter_names=varying_names) # pyright: ignore[reportPossiblyUnbound] logger.debug("Using NLSQAdapter (CurveFit class) for optimization") logger.debug("Attempting optimization with NLSQAdapter") result = adapter.fit_jax( jax_residual_fn=jax_residual_fn, initial_params=initial_varying, bounds=(lower_bounds, upper_bounds), config=config, n_data=n_data, ) if result.success: logger.info("NLSQAdapter optimization succeeded") else: raise RuntimeError(f"Adapter returned success=False: {result.message}") except (ValueError, RuntimeError, TypeError, ImportError, OSError) as e: adapter_error = e logger.warning("NLSQAdapter failed, falling back to wrapper: %s", e) fallback_occurred = True result = None # Wrapper fallback (or primary if use_nlsq_library=False) if result is None and HAS_WRAPPER: try: wrapper = NLSQWrapper(parameter_names=varying_names) # pyright: ignore[reportPossiblyUnbound] logger.debug("Attempting optimization with NLSQWrapper") result = wrapper.fit( residual_fn=numpy_residual_fn, initial_params=initial_varying, bounds=(lower_bounds, upper_bounds), config=config, ) if fallback_occurred: logger.info("NLSQWrapper fallback optimization succeeded") else: logger.info("NLSQWrapper optimization succeeded") except (ValueError, RuntimeError, TypeError, MemoryError) as wrapper_error: logger.error( "Both NLSQAdapter and NLSQWrapper failed: adapter=%s, wrapper=%s", adapter_error, wrapper_error, ) result = NLSQResult( parameters=initial_varying, parameter_names=varying_names, success=False, message=f"All optimizers failed. 
Adapter: {adapter_error}; " f"Wrapper: {wrapper_error}", ) if result is None: raise ImportError( "No NLSQ optimization backend available. " "Ensure heterodyne.optimization.nlsq.adapter is importable." ) # ------------------------------------------------------------------ # Post-fit: compute fitted correlation, update model # ------------------------------------------------------------------ if result.success: full_fitted = param_manager.expand_varying_to_full(result.parameters) fitted_c2 = compute_c2_heterodyne( jnp.asarray(full_fitted), t, q, dt, phi_angle, contrast_val, offset_val, ) result.fitted_correlation = np.asarray(fitted_c2) model.set_params(full_fitted) # ------------------------------------------------------------------ # Post-fit: correct reduced chi-squared # # The raw chi2 from adapter.fit_jax is SSR / (N² − n_params), where # SSR = Σ r² over the full N×N residual vector. Two corrections: # # 1. DOF: the N diagonal residuals are forced to 0 by the # non_diagonal mask in compute_residuals — they should be # excluded from the degrees-of-freedom count. # n_valid = N*(N−1) instead of N². # # 2. σ² normalization: without dividing by measurement noise, # chi2 = MSE ≪ 1 for normalized C2 data (C2 ~ 1, residuals ~ 5%). # We estimate σ²_noise from the far-lag plateau of the C2 matrix # (|t2−t1| ≥ N/2), where correlations have fully decayed and # the remaining variance is photon-counting noise. 
# # chi2_corrected = SSR / (σ²_noise × n_dof_valid) → ~1 for good fits # ------------------------------------------------------------------ if result.final_cost is not None: n_matrix = c2_jax.shape[0] n_valid = (n_matrix - 1) * (n_matrix - 2) # exclude diagonal + t=0 boundary n_dof_valid = max(n_valid - n_varying, 1) c2_np = np.asarray(c2_jax) row_idx = np.arange(n_matrix) lag_mat = np.abs(row_idx[:, None] - row_idx[None, :]) far_mask = lag_mat >= n_matrix // 2 # diagonal (lag=0) not included far_vals = c2_np[far_mask] sigma2_noise = float(np.var(far_vals)) if far_vals.size > 1 else 0.0 if sigma2_noise > 1e-12: ssr = 2.0 * result.final_cost chi2_corrected = ssr / (sigma2_noise * n_dof_valid) logger.debug( "chi2 correction: σ²_noise=%.4e n_valid=%d SSR=%.4e " "raw_chi2=%.4g → chi2_corrected=%.4f", sigma2_noise, n_valid, ssr, result.reduced_chi_squared or float("nan"), chi2_corrected, ) result.reduced_chi_squared = chi2_corrected else: logger.warning( "chi2 noise estimate near-zero (σ²=%.2e); " "reporting uncorrected MSE chi2", sigma2_noise, ) result.metadata["fallback_occurred"] = fallback_occurred if adapter_error is not None: result.metadata["adapter_error"] = str(adapter_error) result.metadata["optimizer"] = "local" result.metadata["wall_time_total"] = time.perf_counter() - t_start _log_result(result) return result


def _make_numpy_residual_fn(
    model: HeterodyneModel,
    c2_data: np.ndarray | jnp.ndarray,
    phi_angle: float,
    weights: np.ndarray | jnp.ndarray | None,
    contrast: float = 1.0,
    offset: float = 1.0,
) -> Any:
    """Build a NumPy-facing residual closure bound to ``model`` and ``c2_data``.

    The returned callable takes the vector of *varying* parameter values and
    returns ``compute_residuals(...)`` for ``phi_angle`` converted with
    ``np.asarray``. Fixed parameter values come from ``model.param_manager``.

    Everything that does not depend on the incoming vector is converted to a
    JAX array once, at construction time: the data and weights (float64),
    the full-parameter template, and the varying-index array. As a result,
    each call performs one ``jnp.asarray`` and one ``.at[].set()`` scatter.

    Args:
        model: Supplies ``param_manager``, ``t``, ``q`` and ``dt``.
        c2_data: Measured correlation data; converted to float64.
        phi_angle: Angle forwarded to ``compute_residuals``.
        weights: Optional weights; converted to float64 when not ``None``.
        contrast: Contrast value forwarded to ``compute_residuals``.
        offset: Offset value forwarded to ``compute_residuals``.

    Returns:
        A callable ``residual_fn(varying_params: np.ndarray) -> np.ndarray``.

    Note:
        The template is a snapshot of ``param_manager.get_full_values()``
        taken when this function runs. Changes made to the parameter manager
        afterwards are not seen by the closure, so do not mutate it while an
        optimizer is still using the returned function.
    """
    pm = model.param_manager
    t, q, dt = model.t, model.q, model.dt

    # One-time host -> device conversions, reused by every residual call.
    target = jnp.asarray(c2_data, dtype=jnp.float64)
    w = None if weights is None else jnp.asarray(weights, dtype=jnp.float64)
    template = jnp.asarray(pm.get_full_values(), dtype=jnp.float64)
    slots = jnp.array(pm.varying_indices, dtype=jnp.int32)

    def residual_fn(varying_params: np.ndarray) -> np.ndarray:
        # Scatter the optimizer's vector into the fixed-value template.
        params = template.at[slots].set(
            jnp.asarray(varying_params, dtype=jnp.float64)
        )
        r = compute_residuals(
            params, t, q, dt, phi_angle, target, w, contrast, offset
        )
        return np.asarray(r)

    return residual_fn


def _select_adapter(
    varying_names: list[str],
    use_nlsq_library: bool,
) -> Any:
    """Instantiate the preferred NLSQ backend that is available.

    Preference order:

    1. ``NLSQAdapter`` is used when ``use_nlsq_library`` is true, the adapter
       module imported (``HAS_ADAPTERS``), and construction does not raise
       ``ImportError``. A construction ``ImportError`` is logged and the
       function falls through to step 2.
    2. Otherwise ``NLSQWrapper`` is used, provided it imported (``HAS_WRAPPER``).

    Raises:
        ImportError: If neither backend can be constructed.
    """
    wants_adapter = use_nlsq_library and HAS_ADAPTERS
    if wants_adapter:
        try:
            adapter = NLSQAdapter(parameter_names=varying_names)  # pyright: ignore[reportPossiblyUnbound]
        except ImportError:
            logger.warning("nlsq library not available, falling back to NLSQWrapper")
        else:
            return adapter

    if not HAS_WRAPPER:
        raise ImportError("No NLSQ adapter available")
    return NLSQWrapper(parameter_names=varying_names)  # pyright: ignore[reportPossiblyUnbound]


def _log_result(result: NLSQResult) -> None:
    """Write a human-readable summary of ``result`` to the module logger at INFO level.

    The summary always includes the status and message. Final cost, reduced
    χ² and wall time are logged only when they are not ``None``. On success,
    each fitted parameter is also listed, with its uncertainty when
    ``result.get_uncertainty`` returns one.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("NLSQ OPTIMIZATION COMPLETE")
    logger.info(banner)
    logger.info("Status: %s", "SUCCESS" if result.success else "FAILED")
    logger.info("Message: %s", result.message)

    # Optional scalar diagnostics, read lazily in the same order as before.
    for fmt, attr in (
        ("Final cost: %.6e", "final_cost"),
        ("Reduced χ²: %.4f", "reduced_chi_squared"),
        ("Wall time: %.2f s", "wall_time_seconds"),
    ):
        value = getattr(result, attr)
        if value is not None:
            logger.info(fmt, value)

    if result.success:
        logger.info("Fitted parameters:")
        logger.info(" Physical parameters:")
        pairs = zip(result.parameter_names, result.parameters, strict=True)
        for pname, pvalue in pairs:
            sigma = result.get_uncertainty(pname)
            if sigma is None:
                logger.info(" %s: %.6g", pname, pvalue)
            else:
                logger.info(" %s: %.6g ± %.3g", pname, pvalue, sigma)
        count = len(result.parameters)
        logger.info(
            " Total parameters: %d physical + 2 scaling = %d",
            count,
            count + 2,
        )

    logger.info(banner)