"""Model class hierarchy for heterodyne correlation analysis."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import jax.numpy as jnp
import numpy as np
from heterodyne.config.parameter_names import ALL_PARAM_NAMES
from heterodyne.core.jax_backend import compute_c2_heterodyne
if TYPE_CHECKING:
pass
[docs]
class HeterodyneModelBase(ABC):
"""Abstract base class for heterodyne models."""
@property
@abstractmethod
def n_params(self) -> int:
"""Number of model parameters."""
...
@property
@abstractmethod
def param_names(self) -> tuple[str, ...]:
"""Parameter names in order."""
...
[docs]
@abstractmethod
def compute_correlation(
self,
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
contrast: float = 1.0,
offset: float = 1.0,
) -> jnp.ndarray:
"""Compute model correlation matrix.
Args:
params: Parameter array
t: Time array
q: Wavevector
dt: Time step
phi_angle: Detector phi angle (degrees)
contrast: Speckle contrast (beta), default 1.0
offset: Baseline offset, default 1.0
Returns:
Correlation matrix
"""
...
[docs]
@abstractmethod
def get_default_params(self) -> np.ndarray:
"""Get default parameter values."""
...
[docs]
@dataclass
class TwoComponentModel(HeterodyneModelBase):
"""Two-component heterodyne correlation model.
Implements the 14-parameter model:
- Reference transport (3): D0_ref, alpha_ref, D_offset_ref
- Sample transport (3): D0_sample, alpha_sample, D_offset_sample
- Velocity (3): v0, beta, v_offset
- Fraction (4): f0, f1, f2, f3
- Angle (1): phi0
"""
_defaults: dict[str, float] = field(default_factory=dict)
[docs]
def __post_init__(self) -> None:
"""Set default parameter values."""
if not self._defaults:
self._defaults = {
"D0_ref": 1e4,
"alpha_ref": 0.0,
"D_offset_ref": 0.0,
"D0_sample": 1e4,
"alpha_sample": 0.0,
"D_offset_sample": 0.0,
"v0": 1e3,
"beta": 0.0,
"v_offset": 0.0,
"f0": 0.5,
"f1": 0.0,
"f2": 0.0,
"f3": 0.0,
"phi0": 0.0,
}
@property
def n_params(self) -> int:
"""Number of parameters (14)."""
return 14
@property
def param_names(self) -> tuple[str, ...]:
"""Parameter names in canonical order."""
return ALL_PARAM_NAMES
[docs]
def compute_correlation(
self,
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
contrast: float = 1.0,
offset: float = 1.0,
) -> jnp.ndarray:
"""Compute two-time heterodyne correlation.
Args:
params: Parameter array, shape (14,)
t: Time array
q: Scattering wavevector
dt: Time step
phi_angle: Detector phi angle (degrees)
contrast: Speckle contrast (beta), default 1.0
offset: Baseline offset, default 1.0
Returns:
Correlation matrix c2(t1, t2), shape (N, N)
"""
return compute_c2_heterodyne(params, t, q, dt, phi_angle, contrast, offset)
[docs]
def get_default_params(self) -> np.ndarray:
"""Get default parameter values as array."""
return np.array([self._defaults[name] for name in ALL_PARAM_NAMES])
[docs]
def params_to_dict(self, params: np.ndarray | jnp.ndarray) -> dict[str, float]:
"""Convert parameter array to dictionary.
Args:
params: Parameter array, shape (14,)
Returns:
Dict mapping names to values
"""
return {name: float(params[i]) for i, name in enumerate(ALL_PARAM_NAMES)}
[docs]
def dict_to_params(self, param_dict: dict[str, float]) -> np.ndarray:
"""Convert parameter dictionary to array.
Args:
param_dict: Dict with parameter names as keys
Returns:
Parameter array, shape (14,)
"""
return np.array(
[param_dict.get(name, self._defaults[name]) for name in ALL_PARAM_NAMES]
)
[docs]
def compute_g1_reference(
self,
params: np.ndarray | jnp.ndarray,
t: jnp.ndarray,
q: float,
) -> jnp.ndarray:
"""Compute reference g1 correlation only (1D visualization helper).
.. note::
Uses pointwise g1(t) = exp(-q²J(t)), which does not represent
the two-time integral physics. For production correlation, use
compute_correlation which uses the integral formulation.
Args:
params: Full parameter array
t: Time array
q: Wavevector
Returns:
g1_ref array
"""
D0, alpha, offset = params[0], params[1], params[2]
# Use jnp.where instead of jnp.maximum to preserve gradients at the
# t=0 floor (jnp.maximum zeros the gradient when t < 1e-10).
t_safe = jnp.where(t > 1e-10, t, 1e-10)
J = D0 * jnp.where(t > 0, jnp.power(t_safe, alpha), 0.0) + offset
# Physical positivity floor. jnp.maximum's JVP averages the two
# tangents at the kink (0.5x), matching the FD subgradient of the
# offset; jnp.where(J >= 0) would route the full tangent (2x) and
# break the FD↔autodiff agreement pinned by test_gradient_finite_
# difference. Allow-listed in test_no_gradient_killing_clip.py.
J = jnp.maximum(J, 0.0)
return jnp.exp(-q * q * J)
[docs]
def compute_g1_sample(
self,
params: np.ndarray | jnp.ndarray,
t: jnp.ndarray,
q: float,
) -> jnp.ndarray:
"""Compute sample g1 correlation only (1D visualization helper).
.. note::
Uses pointwise g1(t) = exp(-q²J(t)), which does not represent
the two-time integral physics. For production correlation, use
compute_correlation which uses the integral formulation.
Args:
params: Full parameter array
t: Time array
q: Wavevector
Returns:
g1_sample array
"""
D0, alpha, offset = params[3], params[4], params[5]
# Use jnp.where instead of jnp.maximum to preserve gradients at the
# t=0 floor (jnp.maximum zeros the gradient when t < 1e-10).
t_safe = jnp.where(t > 1e-10, t, 1e-10)
J = D0 * jnp.where(t > 0, jnp.power(t_safe, alpha), 0.0) + offset
# Physical positivity floor. jnp.maximum's JVP averages the two
# tangents at the kink (0.5x), matching the FD subgradient of the
# offset; jnp.where(J >= 0) would route the full tangent (2x) and
# break the FD↔autodiff agreement pinned by test_gradient_finite_
# difference. Allow-listed in test_no_gradient_killing_clip.py.
J = jnp.maximum(J, 0.0)
return jnp.exp(-q * q * J)
[docs]
def compute_fraction(
self,
params: np.ndarray | jnp.ndarray,
t: jnp.ndarray,
) -> jnp.ndarray:
"""Compute sample fraction only.
Args:
params: Full parameter array
t: Time array
Returns:
f_sample array in [0, 1]
"""
f0, f1, f2, f3 = params[9], params[10], params[11], params[12]
exponent = jnp.clip(f1 * (t - f2), -100, 100)
return jnp.clip(f0 * jnp.exp(exponent) + f3, 0.0, 1.0)
[docs]
@dataclass
class ReducedModel(HeterodyneModelBase):
"""Reduced heterodyne model with a subset of active parameters.
Inactive parameters are held fixed at their canonical default values.
Useful for simplified analysis modes (e.g., reference-only diffusion).
Args:
_active_params: Ordered tuple of parameter names that are free to vary.
"""
_active_params: tuple[str, ...]
# Full default values for all 14 parameters (canonical defaults)
_FULL_DEFAULTS: dict[str, float] = field(
default_factory=lambda: {
"D0_ref": 1e4,
"alpha_ref": 0.0,
"D_offset_ref": 0.0,
"D0_sample": 1e4,
"alpha_sample": 0.0,
"D_offset_sample": 0.0,
"v0": 1e3,
"beta": 0.0,
"v_offset": 0.0,
"f0": 0.5,
"f1": 0.0,
"f2": 0.0,
"f3": 0.0,
"phi0": 0.0,
}
)
[docs]
def __post_init__(self) -> None:
"""Validate active params and precompute expansion constants."""
invalid = [n for n in self._active_params if n not in ALL_PARAM_NAMES]
if invalid:
raise ValueError(
f"Unknown parameter names: {invalid}. "
f"Valid names: {list(ALL_PARAM_NAMES)}"
)
# Precompute template and index mapping for _expand_to_full
object.__setattr__(
self,
"_template",
jnp.array([self._FULL_DEFAULTS[name] for name in ALL_PARAM_NAMES]),
)
idx_list = [ALL_PARAM_NAMES.index(name) for name in self._active_params]
object.__setattr__(self, "_active_indices", tuple(idx_list))
object.__setattr__(
self,
"_active_indices_array",
jnp.array(idx_list, dtype=jnp.int32),
)
@property
def n_params(self) -> int:
"""Number of active (free) parameters."""
return len(self._active_params)
@property
def param_names(self) -> tuple[str, ...]:
"""Active parameter names in order."""
return self._active_params
[docs]
def get_default_params(self) -> np.ndarray:
"""Get default values for active parameters only."""
return np.array([self._FULL_DEFAULTS[name] for name in self._active_params])
def _expand_to_full(self, params: jnp.ndarray) -> jnp.ndarray:
"""Expand active-parameter array to full 14-element array.
Uses precomputed template and index mapping for efficiency.
Inactive parameters retain their canonical defaults.
Args:
params: Active-parameter array, shape (n_params,)
Returns:
Full parameter array, shape (14,)
"""
return self._template.at[self._active_indices_array].set(params) # type: ignore[attr-defined,no-any-return]
[docs]
def compute_correlation(
self,
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
contrast: float = 1.0,
offset: float = 1.0,
) -> jnp.ndarray:
"""Compute model correlation from reduced parameter set.
Inactive parameters are held at canonical defaults.
Args:
params: Active-parameter array, shape (n_params,)
t: Time array
q: Scattering wavevector
dt: Time step
phi_angle: Detector phi angle (degrees)
contrast: Speckle contrast (beta), default 1.0
offset: Baseline offset, default 1.0
Returns:
Correlation matrix c2(t1, t2), shape (N, N)
"""
full_params = self._expand_to_full(params)
return compute_c2_heterodyne(full_params, t, q, dt, phi_angle, contrast, offset)
# ---------------------------------------------------------------------------
# Analysis mode registry
# ---------------------------------------------------------------------------
ANALYSIS_MODES: dict[str, tuple[str, ...]] = {
"static_ref": ("D0_ref", "alpha_ref", "D_offset_ref"),
"static_both": (
"D0_ref",
"alpha_ref",
"D_offset_ref",
"D0_sample",
"alpha_sample",
"D_offset_sample",
),
"two_component": ALL_PARAM_NAMES,
}
[docs]
def create_model(mode: str) -> HeterodyneModelBase:
"""Factory function that returns a model for the requested analysis mode.
Args:
mode: One of ``"static_ref"``, ``"static_both"``, ``"two_component"``.
Returns:
``TwoComponentModel`` for ``"two_component"``;
``ReducedModel`` for all other recognised modes.
Raises:
ValueError: If *mode* is not a recognised analysis mode.
"""
if mode not in ANALYSIS_MODES:
valid = ", ".join(sorted(ANALYSIS_MODES))
raise ValueError(f"Unknown analysis mode '{mode}'. Valid modes: {valid}")
if mode == "two_component":
return TwoComponentModel()
return ReducedModel(_active_params=ANALYSIS_MODES[mode])
# Default model instance
DEFAULT_MODEL = TwoComponentModel()