# Source code for heterodyne.core.heterodyne_model

"""Main heterodyne model wrapper class."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
import numpy as np

from heterodyne.config.parameter_manager import ParameterManager
from heterodyne.config.parameter_names import ALL_PARAM_NAMES
from heterodyne.core.jax_backend import compute_c2_heterodyne, compute_residuals
from heterodyne.core.models import TwoComponentModel
from heterodyne.core.physics_factors import PhysicsFactors, create_physics_factors
from heterodyne.core.scaling_utils import PerAngleScaling, ScalingConfig

if TYPE_CHECKING:
    pass


@dataclass
class HeterodyneModel:
    """Main heterodyne correlation model with stateful parameter management.

    This class provides a convenient interface for:
    - Managing model parameters through ParameterManager
    - Computing correlation matrices
    - Computing residuals for fitting
    - Accessing pre-computed physics factors

    Example:
        >>> model = HeterodyneModel.from_config(config)
        >>> c2 = model.compute_correlation(phi_angle=45.0)
        >>> residuals = model.compute_residuals(c2_data, phi_angle=45.0)
    """

    # Core model
    _model: TwoComponentModel = field(default_factory=TwoComponentModel)
    # Parameter management
    param_manager: ParameterManager = field(default_factory=ParameterManager)
    # Physics factors (pre-computed from config)
    _factors: PhysicsFactors | None = field(default=None)
    # Per-angle scaling (contrast/offset as fitted parameters)
    scaling: PerAngleScaling = field(default_factory=PerAngleScaling)
    # Cached time array; starts as ``_factors.t`` but may be trimmed/extended
    # by sync_time_axis() independently of ``_factors``.
    _t: jnp.ndarray | None = field(default=None)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> HeterodyneModel:
        """Create a model from a configuration dictionary.

        Temporal and scattering values are read from ``analyzer_parameters``
        (canonical) first, falling back to the legacy ``temporal`` and
        ``scattering`` sections, then to built-in defaults.

        Args:
            config: Configuration with ``analyzer_parameters`` (or legacy
                ``temporal`` / ``scattering``), ``parameters`` and an optional
                ``scaling`` section.

        Returns:
            Configured HeterodyneModel.

        Raises:
            KeyError: If ``analyzer_parameters`` contains ``start_frame`` but
                not ``end_frame``.
        """
        param_manager = ParameterManager.from_config(config)

        # Read from analyzer_parameters (canonical) with legacy fallback
        ap = config.get("analyzer_parameters", {})
        temporal = config.get("temporal", {})
        scattering = config.get("scattering", {})

        dt = float(ap.get("dt", temporal.get("dt", 1.0)))
        if "start_frame" in ap:
            start_frame = int(ap["start_frame"])
            end_frame = int(ap["end_frame"])
            # Frame window is inclusive on both ends.
            n_times = end_frame - start_frame + 1
            # Relative time within window: first usable frame at 1×dt.
            t_start = dt
        else:
            n_times = int(temporal.get("time_length", 1000))
            t_start = float(temporal.get("t_start", dt))

        ap_scat = ap.get("scattering", {})
        q = float(ap_scat.get("wavevector_q", scattering.get("wavevector_q", 0.01)))

        factors = create_physics_factors(
            n_times=n_times,
            dt=dt,
            q=q,
            phi_angle=0.0,
            t_start=float(t_start),
        )

        # Per-angle scaling config.
        # Effective priority as implemented: explicit ``scaling`` section
        # values > ParameterSpace scaling values (parameters.scaling overrides,
        # else registry defaults).
        # NOTE(review): the stated goal is that YAML parameter overrides reach
        # PerAngleScaling (homodyne parity), which may imply parameters.scaling
        # should win over the scaling section — confirm the intended order.
        scaling_cfg = config.get("scaling", {})
        space_scaling = param_manager.space.scaling_values
        scaling = PerAngleScaling.from_config(
            ScalingConfig(
                n_angles=int(scaling_cfg.get("n_angles", 1)),
                mode=str(scaling_cfg.get("mode", "constant")),
                initial_contrast=float(
                    scaling_cfg.get("initial_contrast", space_scaling["contrast"])
                ),
                initial_offset=float(
                    scaling_cfg.get("initial_offset", space_scaling["offset"])
                ),
            )
        )

        return cls(
            _model=TwoComponentModel(),
            param_manager=param_manager,
            _factors=factors,
            scaling=scaling,
            _t=factors.t,
        )

    @property
    def n_params(self) -> int:
        """Total number of model parameters (14, i.e. ``len(ALL_PARAM_NAMES)``)."""
        return len(ALL_PARAM_NAMES)

    @property
    def n_varying(self) -> int:
        """Number of varying parameters."""
        return self.param_manager.n_varying

    @property
    def param_names(self) -> tuple[str, ...]:
        """All parameter names in canonical order."""
        return ALL_PARAM_NAMES

    @property
    def varying_names(self) -> list[str]:
        """Names of varying parameters."""
        return self.param_manager.varying_names

    @property
    def q(self) -> float:
        """Scattering wavevector magnitude.

        Raises:
            ValueError: If physics factors are not initialized.
        """
        if self._factors is None:
            raise ValueError("Physics factors not initialized")
        return self._factors.q

    @property
    def dt(self) -> float:
        """Time step.

        Raises:
            ValueError: If physics factors are not initialized.
        """
        if self._factors is None:
            raise ValueError("Physics factors not initialized")
        return self._factors.dt

    @property
    def t(self) -> jnp.ndarray:
        """Time array used for all model evaluations.

        Raises:
            ValueError: If the time array is not initialized.
        """
        if self._t is None:
            raise ValueError("Time array not initialized")
        return self._t

    @property
    def n_times(self) -> int:
        """Number of time points on the model's (possibly synced) time axis.

        Raises:
            ValueError: If physics factors are not initialized.
        """
        if self._factors is None:
            raise ValueError("Physics factors not initialized")
        # sync_time_axis() trims/extends ``_t`` without touching the
        # pre-computed factors, so report the live axis length when available;
        # this keeps n_times consistent with the shape of compute_correlation().
        if self._t is not None:
            return len(self._t)
        return self._factors.n_times

    def sync_time_axis(self, t: np.ndarray) -> None:
        """Match the model time-axis length to the post-exclusion data length.

        The data pipeline may remove leading time points (e.g. t=0 singularity
        exclusion), shrinking the data array. When the data is shorter, the
        same number of leading points is dropped from the model's own
        seconds-based time axis (computed from start_frame and dt), so shapes
        align while the remaining points keep their absolute-time values.
        When the data is longer, the axis is extended at its end with ``dt``
        spacing. Only ``len(t)`` is used; the values in ``t`` are ignored.

        Args:
            t: Data time axis after exclusion.

        Raises:
            ValueError: If the model time axis is not initialized, or (when
                extending) if physics factors are not initialized.
        """
        n_data = len(t)
        if self._t is None:
            raise ValueError("Model time axis not initialized; call from_config first")
        n_model = len(self._t)
        if n_data < n_model:
            self._t = self._t[n_model - n_data :]
        elif n_data > n_model:
            # More data points than model — expand using the model's dt
            # spacing. Build the extension in the axis' own dtype: explicitly
            # requesting float64 while JAX x64 is disabled only triggers a
            # truncation warning.
            n_extra = n_data - n_model
            extra = jnp.arange(1, n_extra + 1, dtype=self._t.dtype) * self.dt
            self._t = jnp.concatenate([self._t, self._t[-1:] + extra])

    def get_params(self) -> np.ndarray:
        """Get the current full parameter array.

        Returns:
            Array of shape (14,)
        """
        return self.param_manager.get_full_values()

    def get_params_dict(self) -> dict[str, float]:
        """Get current parameters as a name -> value dictionary."""
        return self.param_manager.get_parameter_dict()

    def set_params(self, params: np.ndarray | dict[str, float]) -> None:
        """Set parameter values.

        Args:
            params: Either array of shape (14,) or dict with param names
        """
        self.param_manager.update_values(params)

    def _resolve_scaling(
        self,
        contrast: float | None,
        offset: float | None,
        angle_idx: int,
    ) -> tuple[float, float]:
        """Fill any missing contrast/offset from per-angle scaling for ``angle_idx``.

        The per-angle scaling is only queried when at least one override is
        missing; explicit overrides always win.
        """
        if contrast is None or offset is None:
            sc_contrast, sc_offset = self.scaling.get_for_angle(angle_idx)
            if contrast is None:
                contrast = sc_contrast
            if offset is None:
                offset = sc_offset
        return contrast, offset

    def compute_correlation(
        self,
        phi_angle: float = 0.0,
        params: np.ndarray | None = None,
        contrast: float | None = None,
        offset: float | None = None,
        angle_idx: int = 0,
    ) -> jnp.ndarray:
        """Compute the two-time correlation matrix.

        Args:
            phi_angle: Detector phi angle (degrees)
            params: Optional parameter array (uses stored values if None)
            contrast: Speckle contrast override. If None, uses per-angle scaling.
            offset: Baseline offset override. If None, uses per-angle scaling.
            angle_idx: Angle index for per-angle scaling lookup (0-based).

        Returns:
            Correlation matrix c2(t1, t2), shape (N, N)
        """
        if params is None:
            params = self.get_params()
        contrast, offset = self._resolve_scaling(contrast, offset, angle_idx)
        return compute_c2_heterodyne(
            jnp.asarray(params),
            self.t,
            self.q,
            self.dt,
            phi_angle,
            contrast,
            offset,
        )

    def compute_residuals(
        self,
        c2_data: np.ndarray | jnp.ndarray,
        phi_angle: float = 0.0,
        params: np.ndarray | None = None,
        weights: np.ndarray | jnp.ndarray | None = None,
        contrast: float | None = None,
        offset: float | None = None,
        angle_idx: int = 0,
    ) -> jnp.ndarray:
        """Compute residuals between model and data.

        Args:
            c2_data: Experimental correlation data
            phi_angle: Detector phi angle
            params: Optional parameter array (uses stored values if None)
            weights: Optional weights (1/sigma²); passed through as None if omitted
            contrast: Speckle contrast override. If None, uses per-angle scaling.
            offset: Baseline offset override. If None, uses per-angle scaling.
            angle_idx: Angle index for per-angle scaling lookup (0-based).

        Returns:
            Flattened residual array
        """
        if params is None:
            params = self.get_params()
        contrast, offset = self._resolve_scaling(contrast, offset, angle_idx)
        return compute_residuals(
            jnp.asarray(params),
            self.t,
            self.q,
            self.dt,
            phi_angle,
            jnp.asarray(c2_data),
            jnp.asarray(weights) if weights is not None else None,
            contrast,
            offset,
        )

    def compute_g1_reference(self, params: np.ndarray | None = None) -> jnp.ndarray:
        """Compute the reference g1 correlation.

        Args:
            params: Optional parameter array (uses stored values if None)

        Returns:
            g1_ref array, shape (N,)
        """
        if params is None:
            params = self.get_params()
        return self._model.compute_g1_reference(params, self.t, self.q)

    def compute_g1_sample(self, params: np.ndarray | None = None) -> jnp.ndarray:
        """Compute the sample g1 correlation.

        Args:
            params: Optional parameter array (uses stored values if None)

        Returns:
            g1_sample array, shape (N,)
        """
        if params is None:
            params = self.get_params()
        return self._model.compute_g1_sample(params, self.t, self.q)

    def compute_fraction(self, params: np.ndarray | None = None) -> jnp.ndarray:
        """Compute the sample fraction evolution.

        Args:
            params: Optional parameter array (uses stored values if None)

        Returns:
            f_sample array, shape (N,)
        """
        if params is None:
            params = self.get_params()
        return self._model.compute_fraction(params, self.t)

    def create_residual_function(
        self,
        c2_data: np.ndarray | jnp.ndarray,
        phi_angle: float,
        weights: np.ndarray | jnp.ndarray | None = None,
        angle_idx: int = 0,
    ) -> Any:
        """Create a jitted residual function for optimization.

        The returned function takes only the varying parameters and returns
        residuals. Time axis, q, dt, contrast/offset and the values of the
        fixed parameters are captured when this method is called; later
        changes to the model do not affect an already-created function.

        Args:
            c2_data: Experimental correlation data
            phi_angle: Detector phi angle
            weights: Optional weights; uniform (all ones) if None
            angle_idx: Index into per-angle scaling for contrast/offset lookup.

        Returns:
            Callable that maps varying params -> residuals
        """
        c2_jax = jnp.asarray(c2_data)
        weights_jax = (
            jnp.asarray(weights) if weights is not None else jnp.ones_like(c2_jax)
        )
        t = self.t
        q = self.q
        dt = self.dt
        contrast_val, offset_val = self.scaling.get_for_angle(angle_idx)

        varying_idx = np.asarray(self.param_manager.varying_indices)
        if varying_idx.size == 0:
            # An empty index list (all parameters fixed) defaults to a float
            # array, which JAX rejects as an index in ``.at[...]``.
            varying_idx = varying_idx.astype(np.int32)
        varying_idx_jax = jnp.asarray(varying_idx)
        # Snapshot of the full parameter vector; varying entries are
        # overwritten on every call, fixed entries keep these values.
        fixed_values_jax = jnp.array(self.param_manager.get_full_values())

        @jax.jit
        def residual_fn(varying_params: jnp.ndarray) -> jnp.ndarray:
            # Reconstruct full params
            full_params = fixed_values_jax.at[varying_idx_jax].set(varying_params)
            return compute_residuals(
                full_params,
                t,
                q,
                dt,
                phi_angle,
                c2_jax,
                weights_jax,
                contrast_val,
                offset_val,
            )

        return residual_fn

    def summary(self) -> str:
        """Return a human-readable summary of the model configuration.

        Returns:
            Multi-line summary string
        """
        lines = [
            "HeterodyneModel Summary",
            "=" * 40,
            f"Time points: {self.n_times}",
            f"Time step: {self.dt}",
            f"Wavevector q: {self.q}",
            f"Total params: {self.n_params}",
            f"Varying params: {self.n_varying}",
            "",
            "Current Parameters:",
            "-" * 40,
        ]

        params = self.get_params_dict()
        # Query the varying-name list once instead of once per parameter.
        varying = set(self.varying_names)
        for name in ALL_PARAM_NAMES:
            vary = "vary" if name in varying else "fixed"
            lines.append(f"  {name:18s}: {params[name]:12.4e} ({vary})")

        return "\n".join(lines)