"""JAX-accelerated computational backend for heterodyne correlation.
This module provides JIT-compiled functions for computing the two-component
heterodyne correlation function using the integral formulation (PNAS Eq. S-95).
All functions are designed to be stateless and compatible with JAX
transformations (jit, vmap, grad).
The correlation is computed as
``c2 = offset + contrast × [ref + sample + cross] / f²``,
where transport terms use the integral of the rate J(t):
``half_tr[i,j] = exp(-½q² × |∫ J_rate(t') dt'|)``.
The 14 model parameters in canonical order are
D0_ref, alpha_ref, D_offset_ref,
D0_sample, alpha_sample, D_offset_sample,
v0, beta, v_offset, f0, f1, f2, f3, phi0.
Three evaluation strategies are exposed (all dispatch through
:func:`heterodyne.core.physics_kernel.compute_c2_unified`):
- :func:`compute_c2_heterodyne` — full ``(N, N)`` meshgrid output for NLSQ.
- ``compute_c2_elementwise`` (re-exported from :mod:`heterodyne.core.physics_cmc`)
— sharded per-pair ``(n_pairs,)`` output for CMC.
- :func:`compute_c2_heterodyne_pooled` — per-pooled-point ``(n_total,)``
output for joint multi-phi CMC (Phase 4; replaces the older
vmap+gather pattern in :func:`compute_c2_heterodyne_multiphi`).
"""
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING
import jax
import jax.numpy as jnp
from heterodyne.core.physics_utils import (
compute_transport_rate,
compute_velocity_rate,
create_time_integral_matrix,
safe_exp,
smooth_abs,
smooth_clip,
trapezoid_cumsum,
)
if TYPE_CHECKING:
pass
[docs]
@partial(jax.jit, static_argnames=("n_times",))
def compute_transport_jit(
t: jnp.ndarray,
D0: float,
alpha: float,
offset: float,
n_times: int,
) -> jnp.ndarray:
"""JIT-compiled pointwise transport coefficient computation.
J(t) = D0 * t^alpha + offset
.. deprecated::
Pointwise approximation — not used in production correlation.
Production code uses compute_transport_integral_matrix for the
integral formulation (PNAS Eq. S-95). Retained for test
compatibility and 1D visualization helpers.
Args:
t: Time array
D0: Transport prefactor
alpha: Transport exponent
offset: Constant offset
n_times: Number of time points (static for JIT)
Returns:
Transport coefficient array
"""
# Use jnp.where instead of jnp.maximum to preserve gradients at the t=0
# floor (jnp.maximum zeros the gradient when t < 1e-10).
t_safe = jnp.where(t > 1e-10, t, 1e-10)
t_power = jnp.power(t_safe, alpha)
t_power = jnp.where(t > 0, t_power, 0.0)
return D0 * t_power + offset
[docs]
@jax.jit
def compute_g1_transport(
J: jnp.ndarray,
q: float,
) -> jnp.ndarray:
"""JIT-compiled pointwise g1 correlation from transport coefficient.
g1(t) = exp(-q² * J(t))
.. deprecated::
Pointwise approximation — not used in production correlation.
Production code uses the integral formulation via
compute_transport_integral_matrix. Retained for test
compatibility and 1D visualization helpers.
Args:
J: Transport coefficient array
q: Scattering wavevector
Returns:
g1 correlation array
"""
return jnp.exp(-q * q * J)
[docs]
@jax.jit
def compute_fraction_jit(
t: jnp.ndarray,
f0: float,
f1: float,
f2: float,
f3: float,
) -> jnp.ndarray:
"""JIT-compiled sample fraction computation.
f_s(t) = f0 * exp(f1 * (t - f2)) + f3, clipped to [0, 1]
Args:
t: Time array
f0: Amplitude
f1: Exponential rate
f2: Time shift
f3: Baseline
Returns:
Fraction array in [0, 1]
"""
# ``safe_exp`` + ``smooth_clip`` preserve gradient at the [0, 1] boundary
# so NLSQ Jacobian descent does not stall when f(t) saturates (CLAUDE.md
# rule #7 — gradient-safe floors). Mirrors physics_cmc.py to keep the
# element-wise and meshgrid paths bit-equivalent.
fraction = f0 * safe_exp(f1 * (t - f2)) + f3
return smooth_clip(fraction, 0.0, 1.0)
[docs]
@jax.jit
def compute_velocity_integral_matrix(
t: jnp.ndarray,
v0: float,
beta: float,
v_offset: float,
dt: float,
) -> jnp.ndarray:
"""JIT-compiled velocity integral matrix (NLSQ meshgrid path).
Computes M[i,j] = ∫_{t_i}^{t_j} v(t') dt'
where v(t) = v0 * t^beta + v_offset
Uses shared ``trapezoid_cumsum`` → ``create_time_integral_matrix``
pipeline for O(N) efficiency and O(dt²) accuracy.
The velocity integral is *signed* (not absolute-valued) because it
feeds into the phase factor ``cos(q cos(φ) ∫v dt)``.
Args:
t: Time array, shape (N,)
v0: Velocity prefactor
beta: Velocity exponent
v_offset: Velocity offset
dt: Time step
Returns:
Signed integral matrix, shape (N, N)
"""
velocity = compute_velocity_rate(t, v0, beta, v_offset)
cumsum = trapezoid_cumsum(velocity, dt)
return create_time_integral_matrix(cumsum)
[docs]
@jax.jit
def compute_transport_integral_matrix(
t: jnp.ndarray,
D0: float,
alpha: float,
offset: float,
dt: float,
) -> jnp.ndarray:
"""JIT-compiled transport integral matrix (NLSQ meshgrid path).
Computes ``M[i,j] = |∫_{t_i}^{t_j} J_rate(t') dt'|``
where J_rate(t) = D0 * t^alpha + offset
Uses shared ``compute_transport_rate`` → ``trapezoid_cumsum`` →
``create_time_integral_matrix`` → ``smooth_abs`` pipeline.
Args:
t: Time array, shape (N,)
D0: Transport prefactor
alpha: Transport exponent
offset: Transport rate offset
dt: Time step
Returns:
Transport integral matrix, shape (N, N)
"""
J_rate = compute_transport_rate(t, D0, alpha, offset)
cumsum = trapezoid_cumsum(J_rate, dt)
diff = create_time_integral_matrix(cumsum)
return smooth_abs(diff)
[docs]
def compute_c2_heterodyne(
params: jnp.ndarray,
t: jnp.ndarray,
q: float | jnp.ndarray,
dt: float | jnp.ndarray,
phi_angle: float | jnp.ndarray,
contrast: float | jnp.ndarray = 1.0,
offset: float | jnp.ndarray = 1.0,
) -> jnp.ndarray:
"""JIT-compiled two-time heterodyne correlation (meshgrid path).
Thin shim around :func:`heterodyne.core.physics_kernel.compute_c2_unified`
with ``eval_strategy="meshgrid"``. Codex/Gemini G1: the physics math
lives in the unified kernel; this function preserves the legacy import
path for all NLSQ call sites (``compute_residuals``, ``compute_jacobian``,
multi-angle stratified fits, etc.) without behavioural change.
Args:
params: Parameter array of shape ``(14,)`` in canonical order:
``[D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample,
D_offset_sample, v0, beta, v_offset, f0, f1, f2, f3, phi0]``.
t: Time array, shape (N,)
q: Scattering wavevector magnitude
dt: Time step
phi_angle: Detector phi angle (degrees)
contrast: Speckle contrast (beta), default 1.0
offset: Baseline offset, default 1.0
Returns:
Correlation matrix c2, shape (N, N).
"""
# Local import to avoid a cycle: physics_kernel.py imports nothing from
# this module, but importing at module load would chain into JAX init
# before some downstream test helpers expect it.
from heterodyne.core.physics_kernel import compute_c2_unified
return compute_c2_unified( # type: ignore[no-any-return]
params,
q,
dt,
phi_angle,
contrast,
offset,
eval_strategy="meshgrid",
t=t,
)
[docs]
def compute_c2_heterodyne_multiphi(
params: jnp.ndarray,
t: jnp.ndarray,
q: float | jnp.ndarray,
dt: float | jnp.ndarray,
phi_unique: jnp.ndarray,
contrast_arr: jnp.ndarray,
offset_arr: jnp.ndarray,
) -> jnp.ndarray:
"""Joint multi-phi c2 evaluator (vmap reference path — not used in CMC hot path).
Vmap wrapper over :func:`compute_c2_heterodyne` that evaluates the
heterodyne two-component model for every angle in ``phi_unique`` with
the matching per-angle ``contrast_arr`` / ``offset_arr``. Returns a
stacked tensor that can be gathered by ``(phi_indices, i1, i2)`` to
recover c2 at each pooled (t1, t2, phi) tuple.
.. note::
The joint multi-phi CMC likelihood (Phase 4) now calls
:func:`compute_c2_heterodyne_pooled` directly, which computes c2
at the pooled points without ever materializing the
``(n_phi, N, N)`` stack this function returns. This helper is
retained as a vmap reference path for parity tests
(``tests/regression/test_cmc_pooled_kernel_parity.py``) and ad-hoc
diagnostics.
Args:
params: Parameter array of shape ``(14,)`` (same canonical order as
``compute_c2_heterodyne``).
t: Time grid array, shape ``(N,)``.
q: Scattering wavevector magnitude.
dt: Time step.
phi_unique: Unique phi angles, shape ``(n_phi,)``.
contrast_arr: Per-angle contrast values, shape ``(n_phi,)``.
offset_arr: Per-angle offset values, shape ``(n_phi,)``.
Returns:
Stacked correlation, shape ``(n_phi, N, N)``. ``c2[i, j, k]`` is the
model at phi=phi_unique[i], t1=t[j], t2=t[k] with contrast=contrast_arr[i],
offset=offset_arr[i].
"""
def _single(
phi: jnp.ndarray, contrast: jnp.ndarray, offset: jnp.ndarray
) -> jnp.ndarray:
return compute_c2_heterodyne(params, t, q, dt, phi, contrast, offset)
return jax.vmap(_single, in_axes=(0, 0, 0))(phi_unique, contrast_arr, offset_arr)
[docs]
def compute_c2_heterodyne_pooled(
params: jnp.ndarray,
t: jnp.ndarray,
q: float | jnp.ndarray,
dt: float | jnp.ndarray,
idx1: jnp.ndarray,
idx2: jnp.ndarray,
phi_indices: jnp.ndarray,
phi_unique: jnp.ndarray,
contrast_arr: jnp.ndarray,
offset_arr: jnp.ndarray,
) -> jnp.ndarray:
"""Pooled-data c2 evaluator for joint multi-phi CMC (homodyne parity).
Thin shim around :func:`heterodyne.core.physics_kernel.compute_c2_unified`
with ``eval_strategy="pooled"``. Computes c2 directly at every pooled
``(phi_indices[k], idx1[k], idx2[k])`` tuple — never materializes the
``(n_phi, N, N)`` stack that the vmap+gather path required. For
N=1000, n_phi=4 this eliminates a ~256 MB float64 intermediate at every
NUTS leapfrog step.
Phase 4 of the joint multi-phi CMC refactor — replaces the per-step
materialise-then-gather pattern in
:func:`heterodyne.optimization.cmc.model._heterodyne_pooled_likelihood`.
Args:
params: Parameter array of shape ``(14,)`` (same canonical order as
:func:`compute_c2_heterodyne`).
t: Time grid array, shape ``(N,)``.
q: Scattering wavevector magnitude.
dt: Time step.
idx1: Per-pooled-point index into ``t`` for the first time
coordinate, shape ``(n_total,)``. Typically built via
``np.searchsorted(t, t1_flat)``.
idx2: Per-pooled-point index into ``t`` for the second time
coordinate, shape ``(n_total,)``.
phi_indices: Per-pooled-point index into ``phi_unique``, shape
``(n_total,)``.
phi_unique: Sorted unique phi angles, shape ``(n_phi,)``.
contrast_arr: Per-angle contrast values, shape ``(n_phi,)``.
offset_arr: Per-angle offset values, shape ``(n_phi,)``.
Returns:
Pooled correlation values, shape ``(n_total,)``. ``c2[k]`` is the
model at ``phi=phi_unique[phi_indices[k]]``, ``t1=t[idx1[k]]``,
``t2=t[idx2[k]]`` with the matching per-angle scaling.
"""
# Local import to avoid a cycle: physics_kernel.py imports nothing from
# this module, but importing at module load would chain into JAX init
# before some downstream test helpers expect it. Same pattern as
# ``compute_c2_heterodyne`` above.
from heterodyne.core.physics_kernel import compute_c2_unified
return compute_c2_unified( # type: ignore[no-any-return]
params,
q,
dt,
eval_strategy="pooled",
time_grid=t,
idx1=idx1,
idx2=idx2,
phi_unique=phi_unique,
phi_indices=phi_indices,
contrast_arr=contrast_arr,
offset_arr=offset_arr,
)
[docs]
def compute_residuals(
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
c2_data: jnp.ndarray,
weights: jnp.ndarray | None = None,
contrast: float = 1.0,
offset: float = 1.0,
) -> jnp.ndarray:
"""Compute weighted residuals between model and data.
Args:
params: Parameter array, shape (14,)
t: Time array
q: Scattering wavevector
dt: Time step
phi_angle: Detector phi angle
c2_data: Experimental correlation data
weights: Optional weights (1/uncertainty²)
contrast: Speckle contrast (beta), default 1.0
offset: Baseline offset, default 1.0
Returns:
Flattened residual array
"""
if weights is None:
weights = jnp.ones_like(c2_data)
return _compute_residuals_jit( # type: ignore[no-any-return]
params, t, q, dt, phi_angle, c2_data, weights, contrast, offset
)
@jax.jit
def _compute_residuals_jit(
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
c2_data: jnp.ndarray,
weights: jnp.ndarray,
contrast: float,
offset: float,
) -> jnp.ndarray:
"""JIT-compiled residuals computation (always receives weights).
Two boundary exclusions are applied at residual construction (not via
data truncation, which would shorten the model's time grid and break
viz/NPZ shape parity):
1. The t=0 row ``(0, j)`` and t=0 column ``(i, 0)`` are zeroed. The
first frame holds the correlator's raw output; the model evaluates
cleanly at t=0 (``g1(0,0)=1 ⇒ c2 = offset + contrast``), but the
experimental boundary is not used in chi-square fitting.
2. The diagonal ``t1==t2`` is excluded (homodyne parity): corrected
diagonal values are interpolated estimates, not real physics.
The returned vector has shape ``(n_time-1) * (n_time-2)`` — only
off-diagonal pairs where both row > 0 and col > 0 (t=0 boundary
excluded). This gives correct DOF accounting: zero-padded boundary
entries are absent from the array, not just zeroed in-place.
"""
c2_model = compute_c2_heterodyne(params, t, q, dt, phi_angle, contrast, offset)
n_time = c2_data.shape[0]
indices = jnp.arange(n_time)
boundary_mask = (indices[:, None] > 0) & (indices[None, :] > 0)
residuals = (c2_model - c2_data) * jnp.sqrt(weights)
valid_mask = boundary_mask & ~jnp.eye(n_time, dtype=bool)
rows, cols = jnp.nonzero(valid_mask, size=(n_time - 1) * (n_time - 2))
return residuals[rows, cols]
# Jacobian of residuals with respect to parameters (for NLSQ).
# jacfwd (forward-mode) does 14 JVP passes for 14 parameters,
# whereas jacobian (reverse-mode) would do ~N² backward passes
# (one per residual element). For the XPCS use-case with N=200-500
# this is ~8,900x cheaper at N=500 (125K residuals vs 14 params).
_compute_residuals_jacobian_jit = jax.jit(jax.jacfwd(_compute_residuals_jit, argnums=0))
[docs]
@jax.jit
def compute_chi_squared(
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
c2_data: jnp.ndarray,
weights: jnp.ndarray,
contrast: float,
offset: float,
) -> jnp.ndarray:
"""JIT-compiled chi-squared computation.
chi² = sum((c2_model - c2_data)² × weights)
Args:
params: Parameter array, shape (14,)
t: Time array
q: Scattering wavevector
dt: Time step
phi_angle: Detector phi angle
c2_data: Experimental correlation data
weights: Weights (1/uncertainty²)
contrast: Speckle contrast
offset: Baseline offset
Returns:
Chi-squared scalar
"""
c2_model = compute_c2_heterodyne(params, t, q, dt, phi_angle, contrast, offset)
n_time = c2_data.shape[0]
indices = jnp.arange(n_time)
boundary_mask = (indices[:, None] > 0) & (indices[None, :] > 0)
return jnp.sum(
(c2_model - c2_data) ** 2 * weights * boundary_mask.astype(c2_model.dtype)
)
[docs]
def batch_chi_squared(
params_batch: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
c2_data: jnp.ndarray,
weights: jnp.ndarray,
contrast: float = 1.0,
offset: float = 1.0,
chunk_size: int | None = None,
) -> jnp.ndarray:
"""Vectorized chi-squared over a batch of parameter sets.
Uses ``jax.vmap`` for efficient parallel evaluation. For large batches
or large time grids, ``chunk_size`` limits simultaneous N×N allocations
to prevent XLA memory exhaustion (each vmap'd evaluation allocates
multiple N×N intermediate matrices).
Args:
params_batch: Parameter sets, shape ``(n_sets, 14)``.
t: Time array, shape ``(N,)``.
q: Scattering wavevector.
dt: Time step.
phi_angle: Detector phi angle.
c2_data: Experimental data.
weights: Weights.
contrast: Speckle contrast.
offset: Baseline offset.
chunk_size: Max batch elements to vmap simultaneously. ``None``
(default) auto-selects based on time-grid size: ``max(1,
200 // (N // 100))`` to keep peak memory under ~1.6 GB.
Returns:
Chi-squared values, shape ``(n_sets,)``.
"""
n_sets = params_batch.shape[0]
n_times = t.shape[0]
def single_chi2(params: jnp.ndarray) -> jnp.ndarray:
return compute_chi_squared( # type: ignore[no-any-return]
params, t, q, dt, phi_angle, c2_data, weights, contrast, offset
)
if chunk_size is None:
# Auto-select: each evaluation creates ~12 N×N float64 matrices
# (half_tr, cumsum, integral matrix, ref/sample/cross terms, plus
# XLA intermediates) → ~96 N² bytes per evaluation.
# Target peak ≈ 1.6 GB → chunk_size ≈ 1.6e9 / (96 * N²).
matrix_bytes = 96 * n_times * n_times
chunk_size = max(1, int(1.6e9 / max(matrix_bytes, 1)))
if n_sets <= chunk_size:
return jax.vmap(single_chi2)(params_batch)
# Chunked evaluation to bound peak memory
chunks = []
for start in range(0, n_sets, chunk_size):
chunk = params_batch[start : start + chunk_size]
chunks.append(jax.vmap(single_chi2)(chunk))
return jnp.concatenate(chunks)
[docs]
@jax.jit
def compute_multi_angle_residuals(
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angles: jnp.ndarray,
c2_data_batch: jnp.ndarray,
weights_batch: jnp.ndarray,
contrasts: jnp.ndarray,
offsets: jnp.ndarray,
) -> jnp.ndarray:
"""JIT-compiled residuals for multiple phi angles simultaneously.
Args:
params: Parameter array, shape (14,)
t: Time array, shape (N,)
q: Scattering wavevector
dt: Time step
phi_angles: Phi angles, shape (n_phi,)
c2_data_batch: Experimental data, shape (n_phi, N, N)
weights_batch: Weights, shape (n_phi, N, N)
contrasts: Per-angle contrasts, shape (n_phi,)
offsets: Per-angle offsets, shape (n_phi,)
Returns:
Stacked flattened residuals, shape (n_phi × (N-1) × (N-2),)
"""
def single_angle_residual(
phi: jnp.ndarray,
c2_exp: jnp.ndarray,
w: jnp.ndarray,
c: jnp.ndarray,
o: jnp.ndarray,
) -> jnp.ndarray:
# Match _compute_residuals_jit: apply BOTH the t=0 boundary mask
# and the diagonal exclusion before flattening. Previously this
# path only excluded the diagonal, so joint multi-phi fits used
# the t=0 row/col while single-phi fits did not, silently
# changing the chi-square support between the two code paths.
c2_model = compute_c2_heterodyne(params, t, q, dt, phi, c, o)
n_time = c2_exp.shape[0]
indices = jnp.arange(n_time)
boundary_mask = (indices[:, None] > 0) & (indices[None, :] > 0)
residuals = (c2_model - c2_exp) * jnp.sqrt(w)
valid_mask = boundary_mask & ~jnp.eye(n_time, dtype=bool)
rows, cols = jnp.nonzero(valid_mask, size=(n_time - 1) * (n_time - 2))
return residuals[rows, cols]
compute_all = jax.vmap(single_angle_residual, in_axes=(0, 0, 0, 0, 0))
residuals_batch = compute_all(
phi_angles, c2_data_batch, weights_batch, contrasts, offsets
)
return residuals_batch.ravel()
# Gradient of chi-squared with respect to parameters
compute_chi_squared_grad = jax.jit(jax.grad(compute_chi_squared, argnums=0))
# Hessian of chi-squared (for uncertainty estimation).
# Forward-over-reverse (jacfwd ∘ grad) is preferred over hessian()
# (reverse-over-reverse) for a (14,14) output: it runs 14 JVP passes
# over the gradient graph rather than 14 backward passes over 14
# backward passes, giving a ~14x reduction in graph size on CPU.
compute_chi_squared_hessian = jax.jit(
jax.jacfwd(jax.grad(compute_chi_squared, argnums=0), argnums=0)
)
[docs]
def compute_residuals_jacobian(
params: jnp.ndarray,
t: jnp.ndarray,
q: float,
dt: float,
phi_angle: float,
c2_data: jnp.ndarray,
weights: jnp.ndarray | None = None,
contrast: float = 1.0,
offset: float = 1.0,
) -> jnp.ndarray:
"""Compute Jacobian of residuals with respect to parameters.
Args:
params: Parameter array, shape (14,)
t: Time array
q: Scattering wavevector
dt: Time step
phi_angle: Detector phi angle
c2_data: Experimental correlation data
weights: Optional weights (1/uncertainty²)
contrast: Speckle contrast (beta), default 1.0
offset: Baseline offset, default 1.0
Returns:
Jacobian matrix
"""
if weights is None:
weights = jnp.ones_like(c2_data)
return _compute_residuals_jacobian_jit( # type: ignore[no-any-return]
params, t, q, dt, phi_angle, c2_data, weights, contrast, offset
)