"""Pre-computed physics factors for efficient correlation computation."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import jax.numpy as jnp
from heterodyne.utils.logging import get_logger
# Module-level logger, created once at import time.
logger = get_logger(__name__)

# Annotation-only imports would go here; none are currently needed.
if TYPE_CHECKING:
    pass
@dataclass
class PhysicsFactors:
    """Pre-computed factors that don't depend on fit parameters.

    These are computed once from the experimental setup and reused across
    all optimization iterations for efficiency.

    Raises:
        ValueError: From ``__post_init__`` when ``q`` or ``dt`` is not a
            positive finite number, when ``n_times`` is below 1, or when
            ``t`` is not a 1-D array of length ``n_times``.
    """

    # Time arrays
    t: jnp.ndarray  # Time array, shape (N,) with N == n_times
    # Scattering
    q: float  # Wavevector magnitude
    q_squared: float  # q², stored so callers need not recompute it
    # Temporal
    dt: float  # Time step
    n_times: int  # Number of time points; must equal len(t)
    # Geometry
    phi_angle: float  # Detector phi angle (degrees)

    def __post_init__(self) -> None:
        """Validate factors.

        Raises:
            ValueError: If ``q`` or ``dt`` is not positive and finite, if
                ``n_times`` is less than 1, or if ``t`` is not 1-D with
                exactly ``n_times`` entries.
        """
        # ``not (0 < x < inf)`` also rejects NaN: every comparison with NaN
        # is False, so a plain ``x <= 0`` check would let NaN through.
        if not (0 < self.q < float("inf")):
            raise ValueError(f"q must be positive and finite, got {self.q}")
        if not (0 < self.dt < float("inf")):
            raise ValueError(f"dt must be positive and finite, got {self.dt}")
        # An empty grid has no defined extent (time_extent indexes t[0]/t[-1]).
        if self.n_times < 1:
            raise ValueError(f"n_times must be at least 1, got {self.n_times}")
        # Downstream code sizes matrices from n_times and from t; keep them equal.
        if jnp.ndim(self.t) != 1 or jnp.shape(self.t)[0] != self.n_times:
            raise ValueError(
                f"t must be a 1-D array of length n_times={self.n_times}, "
                f"got shape {jnp.shape(self.t)}"
            )

    @property
    def time_extent(self) -> float:
        """Total time span ``t[-1] - t[0]`` (0.0 for a single time point)."""
        return float(self.t[-1] - self.t[0])

    def get_q_cosine(self, phi0: float = 0.0) -> jnp.ndarray:
        """Get q * cos(phi_total) for cross-term phase.

        Args:
            phi0: Additional angle from fit parameters, in degrees (it is
                added to ``phi_angle`` before the degree-to-radian conversion).

        Returns:
            q * cos(phi_angle + phi0) as JAX scalar
        """
        total_phi_rad = jnp.deg2rad(self.phi_angle + phi0)
        return self.q * jnp.cos(total_phi_rad)
def create_physics_factors(
    n_times: int,
    dt: float,
    q: float,
    phi_angle: float = 0.0,
    t_start: float = 0.0,
) -> PhysicsFactors:
    """Create physics factors from experimental parameters.

    The time grid is uniform: ``t[k] = t_start + k * dt`` for
    ``k = 0 .. n_times - 1``.

    Args:
        n_times: Number of time points (must be at least 1)
        dt: Time step
        q: Scattering wavevector magnitude
        phi_angle: Detector phi angle (degrees)
        t_start: Starting time (default 0)

    Returns:
        PhysicsFactors instance

    Raises:
        ValueError: If ``n_times`` is less than 1. ``PhysicsFactors`` may also
            raise ``ValueError`` for an invalid ``q`` or ``dt``.
    """
    n = int(n_times)
    # An empty (or negative-length) grid would produce a PhysicsFactors whose
    # time array is empty and whose n_times is meaningless.
    if n < 1:
        raise ValueError(f"n_times must be at least 1, got {n_times}")

    q_val = float(q)
    dt_val = float(dt)

    # Build the grid from the normalized integer count so len(t) == n.
    t = jnp.arange(n) * dt_val + t_start
    return PhysicsFactors(
        t=t,
        q=q_val,
        # Derive q² from the stored q so the two fields are always consistent.
        q_squared=q_val * q_val,
        dt=dt_val,
        n_times=n,
        phi_angle=float(phi_angle),
    )
def create_physics_factors_from_config(config: dict) -> PhysicsFactors:
    """Create physics factors from configuration dictionary.

    Reads from ``analyzer_parameters`` (canonical) with fallback to legacy
    ``temporal``/``scattering`` top-level sections for backwards compatibility.

    When ``analyzer_parameters.start_frame`` is present, the (inclusive)
    frame window ``start_frame..end_frame`` sets ``n_times`` and the grid
    starts at ``1 * dt``. Otherwise the legacy ``temporal.time_length``
    (default 1000) and ``temporal.t_start`` (default ``dt``) are used.

    Args:
        config: Configuration with ``analyzer_parameters`` or legacy
            ``temporal``/``scattering`` sections.

    Returns:
        PhysicsFactors instance (``phi_angle`` is 0.0; it is set per fit).

    Raises:
        KeyError: If ``start_frame`` is given without ``end_frame``.
        ValueError: If the frame window is empty (``end_frame < start_frame``).
    """
    # ``or {}`` also covers sections that are present but empty (a YAML key
    # with no value loads as None), which ``.get(key, {})`` would pass through.
    ap = config.get("analyzer_parameters") or {}
    temporal = config.get("temporal") or {}
    scattering = config.get("scattering") or {}

    # Prefer analyzer_parameters; fall back to legacy sections
    dt = float(ap.get("dt", temporal.get("dt", 1.0)))

    if "start_frame" in ap:
        if "end_frame" not in ap:
            raise KeyError(
                "analyzer_parameters.end_frame is required when start_frame is set"
            )
        start_frame = int(ap["start_frame"])
        end_frame = int(ap["end_frame"])
        n_times = end_frame - start_frame + 1  # frame bounds are inclusive
        if n_times < 1:
            raise ValueError(
                f"end_frame ({end_frame}) must not precede start_frame ({start_frame})"
            )
        t_start = dt  # relative time within window: first usable frame at 1×dt
    else:
        n_times = int(temporal.get("time_length", 1000))
        t_start = float(temporal.get("t_start", dt))

    ap_scat = ap.get("scattering") or {}
    q = float(ap_scat.get("wavevector_q", scattering.get("wavevector_q", 0.01)))

    # %.4e for q as well: %.4f would truncate typical small wavevector values.
    logger.debug(
        "Physics factors: n_times=%d, dt=%.4e, q=%.4e, t_start=%.4e",
        n_times,
        dt,
        q,
        t_start,
    )
    return create_physics_factors(
        n_times=n_times,
        dt=dt,
        q=q,
        phi_angle=0.0,  # Set per-fit
        t_start=t_start,
    )
@dataclass
class CachedMatrices:
    """Time-grid matrices that are built once and reused for every fit.

    None of these fields depend on fit parameters, so recomputing them
    during fitting would be wasted work.
    """

    time_diff: jnp.ndarray  # |t1 - t2| for each pair of time points
    mean_time: jnp.ndarray  # (t1 + t2) / 2, the "age" of each pair
    triu_indices: tuple[jnp.ndarray, jnp.ndarray]  # (rows, cols) of upper triangle
    tril_indices: tuple[jnp.ndarray, jnp.ndarray]  # (rows, cols) of lower triangle
def create_cached_matrices(factors: PhysicsFactors) -> CachedMatrices:
    """Create cached matrices from physics factors.

    Args:
        factors: PhysicsFactors instance; only ``factors.t`` is read.

    Returns:
        CachedMatrices with ``time_diff[i, j] = |t[i] - t[j]|``,
        ``mean_time[i, j] = (t[i] + t[j]) / 2`` and the upper/lower
        triangular index pairs (diagonal included) for the (N, N) matrices.
    """
    t = jnp.asarray(factors.t)
    # Broadcasting a column against a row yields the same (N, N) values as
    # meshgrid(indexing="ij") without materializing two full coordinate grids.
    t1 = t[:, None]  # varies along axis 0
    t2 = t[None, :]  # varies along axis 1
    time_diff = jnp.abs(t1 - t2)
    # Size the index arrays from the matrices themselves so the two cannot
    # disagree, even if factors.n_times was changed independently of t
    # (PhysicsFactors is a mutable dataclass).
    n = time_diff.shape[0]
    return CachedMatrices(
        time_diff=time_diff,
        mean_time=(t1 + t2) / 2,
        triu_indices=jnp.triu_indices(n),
        tril_indices=jnp.tril_indices(n),
    )