"""NLSQ adapters: NLSQAdapter (JAX-traced) and NLSQWrapper (memory-aware fallback).
Import order: nlsq imports appear before JAX so that nlsq can configure
JAX x64 mode before JAX is initialised.
"""
from __future__ import annotations
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
# ---------------------------------------------------------------------------
# nlsq imports — MUST precede any JAX import so nlsq can set x64 mode
# ---------------------------------------------------------------------------
from nlsq import CurveFit, curve_fit, curve_fit_large # noqa: E402
try:
from nlsq import AdaptiveHybridStreamingOptimizer, HybridStreamingConfig
STREAMING_AVAILABLE = True
except ImportError:
STREAMING_AVAILABLE = False
AdaptiveHybridStreamingOptimizer = None # type: ignore[assignment,misc]
HybridStreamingConfig = None # type: ignore[assignment,misc]
import jax.numpy as jnp # noqa: E402 — must follow nlsq to preserve x64 init order
from heterodyne.optimization.nlsq.adapter_base import NLSQAdapterBase
from heterodyne.optimization.nlsq.config import NLSQConfig
from heterodyne.optimization.nlsq.memory import NLSQStrategy, select_nlsq_strategy
from heterodyne.optimization.nlsq.result_builder import (
build_failed_result,
build_result_from_nlsq,
)
from heterodyne.optimization.nlsq.results import NLSQResult
from heterodyne.utils.logging import get_logger
if TYPE_CHECKING:
pass
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Model cache — avoids re-JIT-compiling CurveFit for identical problem shapes
# ---------------------------------------------------------------------------
_MODEL_CACHE_MAX_SIZE = 64
def _optimizer_kwargs(config: NLSQConfig, method: str) -> dict:
"""Build the kwargs dict passed to every nlsq CurveFit.curve_fit call.
Centralises all config→optimizer parameter mapping so every call site
stays in sync. ``max_nfev=None`` (unlimited) is omitted so nlsq keeps
its own default; an explicit int is passed through.
"""
kw: dict = {
"method": method,
"loss": config.loss,
"ftol": config.ftol,
"xtol": config.xtol,
"gtol": config.gtol,
"x_scale": config.x_scale,
}
if config.max_nfev is not None:
kw["max_nfev"] = config.max_nfev
return kw
[docs]
@dataclass(frozen=True)
class ModelCacheKey:
"""Cache key for CurveFit instances.
Includes phi_angles and scaling_mode so that different multi-angle
or scaling configurations do not share the same compiled fitter.
Homodyne-parity fields (analysis_mode, q, per_angle_scaling) are present
but default to heterodyne-appropriate values so existing callers that use
the fitter-centric path (n_data, n_params, scaling_mode) continue to work.
"""
n_data: int
n_params: int
phi_angles: tuple[float, ...] | None
scaling_mode: str
callable_scope: object | None = None
# Parity fields (homodyne ModelCacheKey)
analysis_mode: str = "full"
q: float = 0.0
per_angle_scaling: bool = True
[docs]
@dataclass
class CachedModel:
"""A cached CurveFit instance with usage stats.
The ``model`` and ``model_func`` fields mirror homodyne's CachedModel
for API parity; in the heterodyne fitter-centric path they remain None
and ``fitter`` carries the nlsq.CurveFit instance.
"""
fitter: object # nlsq.CurveFit
created_at: float = field(default_factory=time.monotonic)
last_accessed: float = field(default_factory=time.monotonic)
n_hits: int = 0
# Parity fields (homodyne CachedModel)
model: Any = None
model_func: Callable[..., Any] | None = None
_model_cache: dict[ModelCacheKey, CachedModel] = {}
_cache_stats: dict[str, int] = {"hits": 0, "misses": 0}
[docs]
def get_or_create_fitter(
n_data: int,
n_params: int,
phi_angles: tuple[float, ...] | None = None,
scaling_mode: str = "auto",
callable_scope: object | None = None,
) -> tuple[object, bool]:
"""Get a CurveFit instance from cache or create a new one.
Args:
n_data: Number of data points (flength).
n_params: Number of parameters.
phi_angles: Tuple of azimuthal angles (distinguishes multi-angle configs).
scaling_mode: Contrast/offset scaling mode (e.g. "auto", "individual").
callable_scope: Optional residual/model callable that must not share a
stateful fitter with different residual closures.
Returns:
Tuple of (CurveFit fitter, cache_hit: bool).
"""
key = ModelCacheKey(
n_data=n_data,
n_params=n_params,
phi_angles=phi_angles,
scaling_mode=scaling_mode,
callable_scope=callable_scope,
)
if key in _model_cache:
_model_cache[key].last_accessed = time.monotonic()
_model_cache[key].n_hits += 1
_cache_stats["hits"] += 1
logger.debug(
"CurveFit cache hit: n_data=%d n_params=%d phi=%s scaling=%s",
n_data,
n_params,
phi_angles,
scaling_mode,
)
return _model_cache[key].fitter, True
_cache_stats["misses"] += 1
# Evict oldest entry if cache is full
if len(_model_cache) >= _MODEL_CACHE_MAX_SIZE:
oldest_key = min(_model_cache, key=lambda k: _model_cache[k].last_accessed)
logger.debug("CurveFit cache eviction: removing oldest entry %s", oldest_key)
del _model_cache[oldest_key]
fitter = CurveFit(flength=int(n_data))
_model_cache[key] = CachedModel(fitter=fitter)
return fitter, False
[docs]
def get_or_create_model(
analysis_mode: str,
phi_angles: np.ndarray,
q: float,
per_angle_scaling: bool = True,
config: dict[str, Any] | None = None,
enable_jit: bool = True,
) -> tuple[Any, Callable[..., Any] | None, bool]:
"""Get cached model or create a new placeholder for heterodyne.
This function provides model instance caching for API parity with
homodyne's ``get_or_create_model()``. Heterodyne's NLSQ path uses a
residual-function interface rather than a high-level model object, so
``model`` and ``model_func`` are returned as ``None`` here — callers
that need a concrete fitter should use ``get_or_create_fitter()``.
Args:
analysis_mode: Physics mode string (heterodyne always uses ``'full'``).
phi_angles: Unique phi angles in radians.
q: Scattering wavevector magnitude.
per_angle_scaling: Whether per-angle contrast/offset is used.
config: Optional config dict (unused; kept for API parity).
enable_jit: Whether JIT compilation is requested (advisory).
Returns:
Tuple of (model, model_func, cache_hit) where model and model_func
are always ``None`` in heterodyne (residual path), and cache_hit
reflects whether the key was already registered.
"""
if len(phi_angles) == 0:
raise ValueError("phi_angles cannot be empty")
if q < 0:
raise ValueError(f"q must be non-negative, got {q}")
normalized_mode = analysis_mode if analysis_mode else "full"
phi_sorted = tuple(float(v) for v in sorted(set(phi_angles)))
q_rounded = round(float(q), 10)
key = ModelCacheKey(
n_data=0,
n_params=0,
phi_angles=phi_sorted,
scaling_mode="auto",
callable_scope=None,
analysis_mode=normalized_mode,
q=q_rounded,
per_angle_scaling=per_angle_scaling,
)
if key in _model_cache:
_model_cache[key].last_accessed = time.monotonic()
_model_cache[key].n_hits += 1
_cache_stats["hits"] += 1
logger.debug(
"Model cache hit: mode=%s, n_phi=%d, q=%.6g, hits=%d",
normalized_mode,
len(phi_angles),
q,
_model_cache[key].n_hits,
)
cached = _model_cache[key]
return cached.model, cached.model_func, True
_cache_stats["misses"] += 1
logger.debug(
"Model cache miss: mode=%s, n_phi=%d, q=%.6g",
normalized_mode,
len(phi_angles),
q,
)
if enable_jit:
logger.debug("JIT flag enabled; actual JIT applied by underlying model or NLSQ")
if len(_model_cache) >= _MODEL_CACHE_MAX_SIZE:
oldest_key = min(_model_cache, key=lambda k: _model_cache[k].last_accessed)
logger.debug("LRU eviction: removed oldest cached model")
del _model_cache[oldest_key]
_model_cache[key] = CachedModel(fitter=None, model=None, model_func=None)
return None, None, False
[docs]
def clear_model_cache() -> int:
"""Clear the CurveFit model cache and reset hit/miss counters.
Returns:
Number of models removed from the cache.
"""
n_cleared = len(_model_cache)
_model_cache.clear()
_cache_stats["hits"] = 0
_cache_stats["misses"] = 0
logger.info("Cleared model cache: %d models removed", n_cleared)
return n_cleared
[docs]
def get_cache_stats() -> dict[str, int]:
"""Return cache hit/miss/size statistics."""
return {**_cache_stats, "size": len(_model_cache)}
# ---------------------------------------------------------------------------
# Shared convergence assessment
# ---------------------------------------------------------------------------
def _assess_convergence(
fitted_params: np.ndarray,
initial_params: np.ndarray,
reduced_chi2: float | None,
) -> tuple[bool, str, str]:
"""Apply post-fit convergence heuristics.
Returns:
(success, message, convergence_reason)
"""
if not np.all(np.isfinite(fitted_params)):
return False, "Non-finite parameters in result", "failed"
if reduced_chi2 is not None and reduced_chi2 > 1e6:
return (
False,
f"Poor fit quality (reduced chi-squared = {reduced_chi2:.2e})",
"poor_fit",
)
if np.allclose(fitted_params, initial_params, rtol=1e-12, atol=1e-12):
return False, "Optimizer made no progress from initial values", "no_progress"
return True, "Optimization converged", "tolerance"
# ---------------------------------------------------------------------------
# AdapterConfig — configuration for NLSQAdapter (homodyne parity)
# ---------------------------------------------------------------------------
[docs]
@dataclass
class AdapterConfig:
"""Configuration for NLSQAdapter.
Attributes:
enable_cache: Enable model instance caching.
enable_jit: Enable JIT compilation of model functions.
enable_recovery: Enable NLSQ's built-in recovery system.
enable_stability: Enable NLSQ's numerical stability guard.
goal: Optimization goal (fast, robust, quality, memory_efficient).
workflow: Workflow tier override (auto, standard, streaming).
"""
enable_cache: bool = True
enable_jit: bool = True
enable_recovery: bool = True
enable_stability: bool = True
goal: str = "quality" # XPCS requires precision
workflow: str = "auto"
# ---------------------------------------------------------------------------
# NLSQAdapter — primary JAX-traced adapter
# ---------------------------------------------------------------------------
[docs]
class NLSQAdapter(NLSQAdapterBase):
"""Adapter for the nlsq library's CurveFit optimizer.
Uses JAX-accelerated nonlinear least squares from the nlsq package.
The ``fit()`` method calls ``nlsq.CurveFit`` directly — no scipy delegation.
For pure-JAX residual functions, prefer ``fit_jax()`` which passes a
JAX-traceable function to ``CurveFit.curve_fit()``.
"""
[docs]
def __init__(
self,
parameter_names: list[str] | None = None,
config: AdapterConfig | None = None,
) -> None:
"""Initialise the adapter.
Args:
parameter_names: Names of parameters being optimised, in order.
Kept as the primary heterodyne argument. Defaults to an empty
list so that ``NLSQAdapter()`` (homodyne-style, no names) works.
config: Optional AdapterConfig for feature flags (parity with
homodyne). When provided, ``enable_recovery`` and
``enable_stability`` are forwarded to the underlying CurveFit
constructor if the installed nlsq version supports them.
"""
self._parameter_names = parameter_names if parameter_names is not None else []
self.config = config or AdapterConfig()
logger.debug(
"NLSQAdapter initialized: cache=%s, recovery=%s, stability=%s, goal=%s",
self.config.enable_cache,
self.config.enable_recovery,
self.config.enable_stability,
self.config.goal,
)
@property
def name(self) -> str:
return "nlsq.CurveFit"
[docs]
def supports_bounds(self) -> bool:
return True
[docs]
def supports_jacobian(self) -> bool:
return True
[docs]
def is_available(self) -> bool:
"""Check if the NLSQ CurveFit backend is available."""
try:
from nlsq import CurveFit as _CurveFit # noqa: F401
return True
except ImportError:
return False
@property
def workflow_available(self) -> bool:
"""Check if NLSQ WorkflowSelector is available.
WorkflowSelector was removed in NLSQ v0.6.0; heterodyne uses
its own ``select_nlsq_strategy()`` from ``memory.py`` instead.
Always returns False for parity with homodyne post-v0.6.0.
"""
return False
[docs]
def fit(
self,
residual_fn: Callable[[np.ndarray], np.ndarray],
initial_params: np.ndarray,
bounds: tuple[np.ndarray, np.ndarray],
config: NLSQConfig,
jacobian_fn: Callable[[np.ndarray], np.ndarray] | None = None,
*,
analysis_mode: str = "full",
per_angle_scaling: bool = True,
diagnostics_enabled: bool = False,
per_angle_scaling_initial: dict[str, list[float]] | None = None,
anti_degeneracy_controller: Any | None = None,
) -> NLSQResult:
"""Run NLSQ optimisation using nlsq.CurveFit.
Wraps the residual function into the ``(xdata, *params)`` signature
expected by ``CurveFit.curve_fit`` and normalises the result via
``build_result_from_nlsq``.
Args:
residual_fn: Callable ``(params: ndarray) -> residuals: ndarray``.
initial_params: Starting parameter values.
bounds: ``(lower, upper)`` bound arrays.
config: Optimisation configuration.
jacobian_fn: Optional analytic Jacobian (unused by CurveFit; kept
for API compatibility).
analysis_mode: Physics mode string (heterodyne always uses
``'full'``). Present for homodyne API parity; not used
internally because heterodyne residuals are pre-computed.
per_angle_scaling: Whether per-angle contrast/offset is used.
Present for homodyne API parity; heterodyne encodes scaling
inside ``residual_fn``.
diagnostics_enabled: Enable extended diagnostics logging.
per_angle_scaling_initial: Initial per-angle contrast/offset.
Present for homodyne API parity; not used in residual path.
anti_degeneracy_controller: Anti-degeneracy controller. When
provided and it exposes ``create_nlsq_callbacks()``, the
returned callbacks are injected into the optimizer call.
Returns:
NLSQResult with fit results.
"""
start_time = time.perf_counter()
lower_bounds, upper_bounds = bounds
initial_params = np.clip(initial_params, lower_bounds, upper_bounds)
n_params = len(initial_params)
logger.info("NLSQAdapter.fit: %d parameters", n_params)
try:
# Probe residual length
probe = residual_fn(initial_params)
n_data = len(np.asarray(probe))
# Create xdata/ydata for CurveFit API (target = zero residuals)
xdata = np.arange(n_data, dtype=np.float64)
ydata = np.zeros(n_data, dtype=np.float64)
# Wrap residual_fn into (xdata, *params) signature.
# jnp.array is required: nlsq 0.6.12 calls func(xdata, *args) inside
# @jit, so *params are traced JAX scalars — np.array would raise
# TracerArrayConversionError.
def _wrapped(x: np.ndarray, *params: Any) -> Any:
# jnp.Array satisfies the ndarray protocol at runtime; ignore static mismatch.
return residual_fn(jnp.array(params, dtype=jnp.float64)) # type: ignore[arg-type]
fitter, cache_hit = get_or_create_fitter(
n_data=n_data,
n_params=n_params,
phi_angles=None,
scaling_mode="auto",
callable_scope=residual_fn,
)
if cache_hit:
logger.debug("CurveFit cache hit for shape (%d, %d)", n_data, n_params)
# Resolve method — dogbox is not supported by CurveFit
method = config.method
if method == "dogbox":
logger.warning("Method 'dogbox' not supported by CurveFit; using 'trf'")
method = "trf"
logger.info(
"NLSQAdapter settings: method=%s loss=%s gtol=%.2e "
"max_nfev=%s x_scale=%s",
method,
config.loss,
config.gtol,
config.max_nfev
if config.max_nfev is not None
else f"auto({100 * n_params})",
config.x_scale,
)
optimizer_kw = _optimizer_kwargs(config, method)
# Inject anti-degeneracy callbacks if controller provides them
if anti_degeneracy_controller is not None:
if hasattr(anti_degeneracy_controller, "create_nlsq_callbacks"):
callbacks = anti_degeneracy_controller.create_nlsq_callbacks()
if callbacks:
optimizer_kw.update(callbacks)
logger.debug(
"Injected anti-degeneracy callbacks: %s",
list(callbacks.keys()),
)
nlsq_result = fitter.curve_fit( # type: ignore[union-attr]
f=_wrapped,
xdata=xdata,
ydata=ydata,
p0=initial_params,
bounds=(lower_bounds, upper_bounds),
**optimizer_kw,
)
wall_time = time.perf_counter() - start_time
# Normalise via result_builder (handles tuple / object / dict formats)
result = build_result_from_nlsq(
nlsq_result=nlsq_result,
parameter_names=self._parameter_names,
n_data=n_data,
wall_time=wall_time,
)
# Apply convergence heuristics on top of build_result_from_nlsq
success, message, reason = _assess_convergence(
fitted_params=result.parameters,
initial_params=initial_params,
reduced_chi2=result.reduced_chi_squared,
)
if not success:
logger.warning("NLSQAdapter convergence check failed: %s", message)
# Return a corrected result with success=False
return NLSQResult(
parameters=result.parameters,
parameter_names=self._parameter_names,
success=False,
message=message,
uncertainties=result.uncertainties,
covariance=result.covariance,
final_cost=result.final_cost,
reduced_chi_squared=result.reduced_chi_squared,
n_iterations=result.n_iterations,
n_function_evals=result.n_function_evals,
convergence_reason=reason,
residuals=result.residuals,
jacobian=result.jacobian,
wall_time_seconds=wall_time,
metadata=result.metadata,
)
return result
except (RuntimeError, ValueError, TypeError) as exc:
logger.error("NLSQAdapter.fit failed: %s", exc)
wall_time = time.perf_counter() - start_time
return build_failed_result(
parameter_names=self._parameter_names,
message=str(exc),
initial_params=initial_params,
wall_time=wall_time,
)
[docs]
def fit_jax(
self,
jax_residual_fn: Callable[..., Any],
initial_params: np.ndarray,
bounds: tuple[np.ndarray, np.ndarray],
config: NLSQConfig,
n_data: int,
) -> NLSQResult:
"""Run NLSQ optimisation using a pure JAX-traceable residual function.
This method accepts a function with the signature
``(xdata, *params) -> residuals`` that nlsq can trace through JAX.
Args:
jax_residual_fn: JAX-compatible callable ``(x, *params) -> residuals``.
initial_params: Starting parameter values.
bounds: ``(lower, upper)`` bound arrays.
config: Optimisation configuration.
n_data: Number of data points (used as CurveFit ``flength``).
Returns:
NLSQResult with fit results.
"""
start_time = time.perf_counter()
lower_bounds, upper_bounds = bounds
initial_params = np.clip(initial_params, lower_bounds, upper_bounds)
n_params = len(initial_params)
logger.info(
"NLSQAdapter.fit_jax: %d parameters, %d data points", n_params, n_data
)
try:
xdata = np.arange(n_data, dtype=np.float64)
ydata = np.zeros(n_data, dtype=np.float64)
fitter, cache_hit = get_or_create_fitter(
n_data=n_data,
n_params=n_params,
phi_angles=None,
scaling_mode="auto",
callable_scope=jax_residual_fn,
)
if cache_hit:
logger.debug("CurveFit cache hit for shape (%d, %d)", n_data, n_params)
method = config.method
if method == "dogbox":
logger.warning("Method 'dogbox' not supported by CurveFit; using 'trf'")
method = "trf"
logger.info(
"NLSQAdapter.fit_jax settings: method=%s loss=%s gtol=%.2e "
"max_nfev=%s x_scale=%s",
method,
config.loss,
config.gtol,
config.max_nfev
if config.max_nfev is not None
else f"auto({100 * n_params})",
config.x_scale,
)
nlsq_result = fitter.curve_fit( # type: ignore[union-attr]
f=jax_residual_fn,
xdata=xdata,
ydata=ydata,
p0=initial_params,
bounds=(lower_bounds, upper_bounds),
**_optimizer_kwargs(config, method),
)
wall_time = time.perf_counter() - start_time
# Normalise result via build_result_from_nlsq (single source of truth)
base = build_result_from_nlsq(
nlsq_result=nlsq_result,
parameter_names=self._parameter_names,
n_data=n_data,
wall_time=wall_time,
)
# Use residuals already stored in the optimizer result.
# ``build_result_from_nlsq`` extracts them from ``nlsq_result.fun``
# which is the final residual vector the optimizer converged to —
# re-evaluating at the same point is numerically identical and wastes
# one full N×N forward pass. Fall back to re-evaluation only when
# the optimizer did not expose residuals (e.g. streaming backends).
if base.residuals is not None:
final_residuals = base.residuals
else:
logger.debug(
"fit_jax: optimizer did not expose final residuals; "
"re-evaluating residual function"
)
final_residuals_jax = jax_residual_fn(
jnp.arange(n_data), *base.parameters
)
final_residuals = np.asarray(final_residuals_jax)
final_cost = 0.5 * float(np.sum(final_residuals**2))
n_dof = n_data - n_params
reduced_chi2: float | None = 2.0 * final_cost / n_dof if n_dof > 0 else None
success, message, reason = _assess_convergence(
fitted_params=base.parameters,
initial_params=initial_params,
reduced_chi2=reduced_chi2,
)
if not success:
logger.warning("fit_jax convergence check failed: %s", message)
return NLSQResult(
parameters=base.parameters,
parameter_names=self._parameter_names,
success=success,
message=message,
uncertainties=base.uncertainties,
covariance=base.covariance,
final_cost=final_cost,
reduced_chi_squared=reduced_chi2,
n_iterations=base.n_iterations,
n_function_evals=base.n_function_evals,
convergence_reason=reason,
residuals=final_residuals,
jacobian=base.jacobian,
wall_time_seconds=wall_time,
metadata=base.metadata,
)
except (RuntimeError, ValueError, TypeError) as exc:
logger.error("NLSQAdapter.fit_jax failed: %s", exc)
wall_time = time.perf_counter() - start_time
return build_failed_result(
parameter_names=self._parameter_names,
message=str(exc),
initial_params=initial_params,
wall_time=wall_time,
)
# ---------------------------------------------------------------------------
# NLSQWrapper — stable fallback with memory-aware strategy routing
# ---------------------------------------------------------------------------
[docs]
class NLSQWrapper(NLSQAdapterBase):
"""Stable fallback adapter with memory-aware strategy routing.
Selects between STANDARD, LARGE, and STREAMING optimization tiers based
on the estimated peak memory usage of the Jacobian matrix. Falls back
down the tier list if a higher tier fails.
Fallback order (descending resource intensity):
STREAMING → LARGE → STANDARD
Each tier is retried up to ``max_retries`` times before falling back.
"""
[docs]
def __init__(
self,
parameter_names: list[str],
enable_large_dataset: bool = True,
enable_recovery: bool = True,
max_retries: int = 3,
) -> None:
"""Initialise the wrapper.
Args:
parameter_names: Names of parameters being optimised, in order.
enable_large_dataset: Allow the LARGE tier when memory warrants it.
enable_recovery: Enable cross-tier fallback on failure.
max_retries: Maximum per-tier retries before falling back.
"""
self._parameter_names = parameter_names
self._enable_large_dataset = enable_large_dataset
self._enable_recovery = enable_recovery
self._max_retries = max(1, max_retries)
@property
def name(self) -> str:
return "nlsq.NLSQWrapper"
[docs]
def supports_bounds(self) -> bool:
return True
[docs]
def supports_jacobian(self) -> bool:
return True
[docs]
def fit(
self,
residual_fn: Callable[[np.ndarray], np.ndarray],
initial_params: np.ndarray,
bounds: tuple[np.ndarray, np.ndarray],
config: NLSQConfig,
jacobian_fn: Callable[[np.ndarray], np.ndarray] | None = None,
) -> NLSQResult:
"""Run NLSQ optimisation with automatic memory-based strategy routing.
Args:
residual_fn: Callable ``(params: ndarray) -> residuals: ndarray``.
initial_params: Starting parameter values.
bounds: ``(lower, upper)`` bound arrays.
config: Optimisation configuration.
jacobian_fn: Optional analytic Jacobian (for API compatibility).
Returns:
NLSQResult with fit results.
"""
start_time = time.perf_counter()
lower_bounds, upper_bounds = bounds
initial_params = np.clip(initial_params, lower_bounds, upper_bounds)
n_params = len(initial_params)
# --- Determine data size via a probe call ---
try:
probe = residual_fn(initial_params)
n_data = len(np.asarray(probe))
except Exception as exc: # noqa: BLE001
logger.error("NLSQWrapper: residual probe failed: %s", exc)
wall_time = time.perf_counter() - start_time
return build_failed_result(
parameter_names=self._parameter_names,
message=f"Residual probe failed: {exc}",
initial_params=initial_params,
wall_time=wall_time,
)
# --- Memory-based strategy selection ---
decision = select_nlsq_strategy(n_points=n_data, n_params=n_params)
logger.info(
"NLSQWrapper strategy: %s (%s)",
decision.strategy.value,
decision.reason,
)
# --- Build xdata/ydata for CurveFit-style API ---
xdata = np.arange(n_data, dtype=np.float64)
ydata = np.zeros(n_data, dtype=np.float64)
# jnp.array required: nlsq 0.6.12 calls func(xdata, *args) inside @jit.
def _wrapped(x: np.ndarray, *params: Any) -> Any:
return residual_fn(jnp.array(params, dtype=jnp.float64)) # type: ignore[arg-type]
method = config.method
if method == "dogbox":
logger.warning("Method 'dogbox' unsupported by nlsq; using 'trf'")
method = "trf"
loss = config.loss
logger.info(
"NLSQWrapper settings: method=%s ftol=%.2e xtol=%.2e gtol=%.2e "
"max_nfev=%s x_scale=%r (loss=%r not applied; "
"STREAMING tier ignores tolerances)",
method,
config.ftol,
config.xtol,
config.gtol,
config.max_nfev
if config.max_nfev is not None
else f"auto({100 * n_params})",
config.x_scale,
loss,
)
# --- Tier ordering: initial strategy → fallback cascade ---
tiers = self._build_tier_list(decision.strategy)
last_exc: Exception | None = None
for tier in tiers:
result = self._try_tier(
tier=tier,
wrapped_fn=_wrapped,
xdata=xdata,
ydata=ydata,
initial_params=initial_params,
bounds=(lower_bounds, upper_bounds),
n_data=n_data,
n_params=n_params,
method=method,
loss=loss,
config=config,
start_time=start_time,
)
if result is not None:
return result
# Result is None → this tier exhausted all retries; try next
last_exc = RuntimeError(
f"Tier {tier.value} failed after {self._max_retries} retries"
)
if not self._enable_recovery:
break
# All tiers failed
wall_time = time.perf_counter() - start_time
message = str(last_exc) if last_exc else "All NLSQ tiers failed"
logger.error("NLSQWrapper: %s", message)
return build_failed_result(
parameter_names=self._parameter_names,
message=message,
initial_params=initial_params,
wall_time=wall_time,
)
def _build_tier_list(self, initial_strategy: NLSQStrategy) -> list[NLSQStrategy]:
"""Return ordered list of tiers to attempt, starting from initial_strategy."""
all_tiers = [NLSQStrategy.STREAMING, NLSQStrategy.LARGE, NLSQStrategy.STANDARD]
# Start from the selected strategy and work downward
try:
start_idx = all_tiers.index(initial_strategy)
except ValueError:
start_idx = len(all_tiers) - 1 # default to STANDARD
tiers = all_tiers[start_idx:]
# Drop LARGE if large-dataset support is disabled
if not self._enable_large_dataset and NLSQStrategy.LARGE in tiers:
tiers = [t for t in tiers if t != NLSQStrategy.LARGE]
return tiers
def _try_tier(
self,
tier: NLSQStrategy,
wrapped_fn: Callable[..., np.ndarray],
xdata: np.ndarray,
ydata: np.ndarray,
initial_params: np.ndarray,
bounds: tuple[np.ndarray, np.ndarray],
n_data: int,
n_params: int,
method: str,
loss: str,
start_time: float,
config: NLSQConfig | None = None,
) -> NLSQResult | None:
"""Attempt a single tier up to max_retries times.
Returns:
NLSQResult on success, or None if all retries failed.
"""
lower_bounds, upper_bounds = bounds
for attempt in range(self._max_retries):
try:
raw_result = self._call_tier(
tier=tier,
wrapped_fn=wrapped_fn,
xdata=xdata,
ydata=ydata,
p0=initial_params,
lower_bounds=lower_bounds,
upper_bounds=upper_bounds,
n_data=n_data,
n_params=n_params,
method=method,
loss=loss,
config=config,
)
wall_time = time.perf_counter() - start_time
result = build_result_from_nlsq(
nlsq_result=raw_result,
parameter_names=self._parameter_names,
n_data=n_data,
wall_time=wall_time,
metadata={"strategy": tier.value, "attempt": attempt},
)
# Apply convergence heuristics (same as NLSQAdapter)
success, message, reason = _assess_convergence(
fitted_params=result.parameters,
initial_params=initial_params,
reduced_chi2=result.reduced_chi_squared,
)
if not success:
logger.warning(
"NLSQWrapper: tier %s convergence check failed: %s",
tier.value,
message,
)
result = NLSQResult(
parameters=result.parameters,
parameter_names=self._parameter_names,
success=False,
message=message,
uncertainties=result.uncertainties,
covariance=result.covariance,
final_cost=result.final_cost,
reduced_chi_squared=result.reduced_chi_squared,
n_iterations=result.n_iterations,
n_function_evals=result.n_function_evals,
convergence_reason=reason,
residuals=result.residuals,
jacobian=result.jacobian,
wall_time_seconds=wall_time,
metadata=result.metadata,
)
logger.info(
"NLSQWrapper: tier %s succeeded on attempt %d/%d",
tier.value,
attempt + 1,
self._max_retries,
)
return result
except Exception as exc: # noqa: BLE001
logger.warning(
"NLSQWrapper: tier %s attempt %d/%d failed: %s",
tier.value,
attempt + 1,
self._max_retries,
exc,
)
return None
def _call_tier(
self,
tier: NLSQStrategy,
wrapped_fn: Callable[..., np.ndarray],
xdata: np.ndarray,
ydata: np.ndarray,
p0: np.ndarray,
lower_bounds: np.ndarray,
upper_bounds: np.ndarray,
n_data: int,
n_params: int,
method: str,
loss: str,
config: NLSQConfig | None = None,
) -> Any:
"""Dispatch a single call to the appropriate nlsq function/class.
``ftol``/``xtol``/``gtol``/``x_scale``/``max_nfev`` are propagated to
STANDARD and LARGE tiers via ``**kwargs`` (both ``nlsq.curve_fit`` and
``curve_fit_large`` forward unknown kwargs to scipy ``least_squares``).
``loss`` is intentionally omitted on all tiers: the NLSQWrapper path
wraps residuals as a plain numpy function, so robust-loss kernels would
re-enter JAX tracing and raise ``TracerArrayConversionError``.
STREAMING tier ignores tolerances (fixed ``AdaptiveHybridStreamingOptimizer``
signature).
"""
# Solver kwargs propagated to STANDARD and LARGE tiers.
solver_kwargs: dict[str, Any] = {}
if config is not None:
solver_kwargs["ftol"] = config.ftol
solver_kwargs["xtol"] = config.xtol
solver_kwargs["gtol"] = config.gtol
solver_kwargs["x_scale"] = config.x_scale
if config.max_nfev is not None:
solver_kwargs["max_nfev"] = config.max_nfev
if tier == NLSQStrategy.STREAMING:
if not STREAMING_AVAILABLE or AdaptiveHybridStreamingOptimizer is None:
raise RuntimeError(
"AdaptiveHybridStreamingOptimizer not available in this nlsq build"
)
optimizer = AdaptiveHybridStreamingOptimizer()
return optimizer.fit(
data_source=(xdata, ydata),
func=wrapped_fn,
p0=p0,
bounds=(lower_bounds, upper_bounds),
)
if tier == NLSQStrategy.LARGE:
return curve_fit_large(
f=wrapped_fn,
xdata=xdata,
ydata=ydata,
p0=p0,
bounds=(lower_bounds, upper_bounds),
**solver_kwargs,
)
# STANDARD tier. loss intentionally omitted — see docstring.
_ = loss
return curve_fit( # type: ignore[call-arg, arg-type]
f=wrapped_fn,
xdata=xdata,
ydata=ydata,
p0=p0,
bounds=(lower_bounds, upper_bounds),
method=method, # type: ignore[arg-type]
**solver_kwargs,
)
# ---------------------------------------------------------------------------
# Module-level factory functions (homodyne parity)
# ---------------------------------------------------------------------------
[docs]
def get_adapter(config: AdapterConfig | None = None) -> NLSQAdapter:
"""Factory function to get an NLSQAdapter instance.
Args:
config: Adapter configuration. If ``None``, uses defaults.
Returns:
NLSQAdapter instance.
"""
return NLSQAdapter(config=config)
[docs]
def is_adapter_available() -> bool:
"""Check if NLSQAdapter can be used.
Returns:
True if the nlsq CurveFit class is importable.
"""
try:
from nlsq import CurveFit as _CurveFit # noqa: F401
return True
except ImportError:
return False
# ---------------------------------------------------------------------------
# Disambiguation alias
# ---------------------------------------------------------------------------
#
# ``adapter.NLSQWrapper`` is the LOW-LEVEL wrapper (residual_fn → result).
# ``wrapper.NLSQWrapper`` (re-exported as ``heterodyne.optimization.nlsq.NLSQWrapper``)
# is the HIGH-LEVEL stable-fallback adapter (data + config → result) that
# delegates internally to this low-level class.
#
# The two share the homodyne-parity name "NLSQWrapper" but have different
# signatures and roles. Internal callers (``core.py``, ``fallback_chain.py``,
# ``wrapper.py``) should prefer the explicit alias below to make the routing
# unambiguous at the import site.
LowLevelNLSQWrapper = NLSQWrapper
__all__ = [
"AdapterConfig",
"CachedModel",
"ModelCacheKey",
"NLSQAdapter",
"NLSQWrapper",
"STREAMING_AVAILABLE",
"clear_model_cache",
"get_adapter",
"get_cache_stats",
"get_or_create_fitter",
"get_or_create_model",
"is_adapter_available",
]