JAX Backend
CPU-accelerated correlation computation via the NLSQ meshgrid path. Functions in this module build full N x N integral matrices using cumulative trapezoidal sums and are designed for JIT compilation.
JAX-accelerated computational backend for heterodyne correlation.
This module provides JIT-compiled functions for computing the two-component heterodyne correlation function using the integral formulation (PNAS Eq. S-95). All functions are designed to be stateless and compatible with JAX transformations (jit, vmap, grad).
The correlation is computed as
c2 = offset + contrast × [ref + sample + cross] / f²,
where the transport terms use the integral of the rate J_rate(t) between the two times:
half_tr[i,j] = exp(-½q² × |∫_{t_i}^{t_j} J_rate(t') dt'|).
The 14 model parameters in canonical order are D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample, D_offset_sample, v0, beta, v_offset, f0, f1, f2, f3, phi0.
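Because every function below takes the parameters as a flat (14,) array, it helps to see that canonical order spelled out as index positions. The following is a minimal sketch. `PARAM_NAMES` and `unpack_params` are hypothetical helpers introduced only for illustration and are not part of this module.

```python
import jax.numpy as jnp

# Canonical order of the 14 model parameters, as listed above (hypothetical constant).
PARAM_NAMES = (
    "D0_ref", "alpha_ref", "D_offset_ref",
    "D0_sample", "alpha_sample", "D_offset_sample",
    "v0", "beta", "v_offset",
    "f0", "f1", "f2", "f3",
    "phi0",
)

def unpack_params(params):
    """Hypothetical helper: map a (14,) parameter array to named scalars."""
    params = jnp.asarray(params)
    if params.shape != (len(PARAM_NAMES),):
        raise ValueError(f"expected shape ({len(PARAM_NAMES)},), got {params.shape}")
    return {name: params[i] for i, name in enumerate(PARAM_NAMES)}

p = unpack_params(jnp.arange(14.0))
print(float(p["v0"]), float(p["phi0"]))  # index 6 and index 13
```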
Three evaluation strategies are exposed (all dispatch through heterodyne.core.physics_kernel.compute_c2_unified()):
- compute_c2_heterodyne(): full (N, N) meshgrid output for NLSQ.
- compute_c2_elementwise (re-exported from heterodyne.core.physics_cmc): sharded per-pair (n_pairs,) output for CMC.
- compute_c2_heterodyne_pooled(): per-pooled-point (n_total,) output for joint multi-phi CMC (Phase 4; replaces the older vmap+gather pattern in compute_c2_heterodyne_multiphi()).
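The strategies differ only in *where* the model is evaluated: on the full (N, N) grid, or directly at selected index pairs. The sketch below illustrates that equivalence. `toy_c2` is a made-up stand-in kernel, not the heterodyne physics. Gathering from the meshgrid result gives the same values as evaluating at the pairs directly.

```python
import jax.numpy as jnp

def toy_c2(t1, t2, phi, contrast=0.5, offset=1.0):
    # Stand-in kernel (NOT the heterodyne model): any elementwise function of
    # (t1, t2, phi) is enough to compare the evaluation strategies.
    return offset + contrast * jnp.exp(-jnp.abs(t1 - t2)) * jnp.cos(phi) ** 2

t = jnp.linspace(0.0, 1.0, 5)
phi = 0.3

# "meshgrid" strategy: the full (N, N) matrix.
T1, T2 = jnp.meshgrid(t, t, indexing="ij")
c2_full = toy_c2(T1, T2, phi)

# "elementwise"/"pooled" strategy: evaluate only at the requested index pairs.
idx1 = jnp.array([0, 2, 4, 1])
idx2 = jnp.array([3, 2, 0, 1])
c2_pairs = toy_c2(t[idx1], t[idx2], phi)

# Both strategies agree on the shared points.
print(bool(jnp.allclose(c2_full[idx1, idx2], c2_pairs)))
```

The per-pair form only ever allocates arrays of length n_pairs, which is what makes it attractive inside a sampler.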
- heterodyne.core.jax_backend.compute_transport_jit(t, D0, alpha, offset, n_times)
JIT-compiled pointwise transport coefficient computation.
J(t) = D0 * t^alpha + offset
Deprecated: pointwise approximation, not used in production correlation. Production code uses compute_transport_integral_matrix for the integral formulation (PNAS Eq. S-95). Retained for test compatibility and 1D visualization helpers.
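A minimal sketch of the pointwise formula J(t) = D0 * t^alpha + offset under `jax.jit`. `transport_pointwise` is a hypothetical name, and the original `n_times` argument is omitted because its role is not described here.

```python
import jax
import jax.numpy as jnp

@jax.jit
def transport_pointwise(t, D0, alpha, offset):
    # J(t) = D0 * t**alpha + offset, evaluated elementwise on the time grid.
    return D0 * jnp.power(t, alpha) + offset

t = jnp.array([0.0, 1.0, 2.0, 4.0])
J = transport_pointwise(t, 2.0, 0.5, 0.1)
print(J)  # 2 * sqrt(t) + 0.1 at each grid point
```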
- heterodyne.core.jax_backend.compute_g1_transport(J, q)
JIT-compiled pointwise g1 correlation from transport coefficient.
g1(t) = exp(-q² * J(t))
Deprecated: pointwise approximation, not used in production correlation. Production code uses the integral formulation via compute_transport_integral_matrix. Retained for test compatibility and 1D visualization helpers.
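The pointwise g1 is a single exponential of the transport coefficient. This sketch (hypothetical name `g1_pointwise`) shows that g1 equals 1 where J = 0 and decays as J grows.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g1_pointwise(J, q):
    # g1(t) = exp(-q**2 * J(t)); decreases monotonically in J for q != 0.
    return jnp.exp(-q**2 * J)

J = jnp.array([0.0, 1.0, 2.0])
g1 = g1_pointwise(J, 0.5)
print(g1)  # first entry is exactly 1 because J = 0
```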
- heterodyne.core.jax_backend.compute_fraction_jit(t, f0, f1, f2, f3)
JIT-compiled sample fraction computation.
f_s(t) = f0 * exp(f1 * (t - f2)) + f3, clipped to [0, 1]
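A minimal sketch of the clipped sample fraction, assuming clipping is applied after the exponential is evaluated (as the formula reads). `sample_fraction` is a hypothetical name. With a growing exponential the raw value can exceed 1, and the clip pins it to the physical range.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sample_fraction(t, f0, f1, f2, f3):
    # f_s(t) = f0 * exp(f1 * (t - f2)) + f3, clipped to the physical range [0, 1].
    raw = f0 * jnp.exp(f1 * (t - f2)) + f3
    return jnp.clip(raw, 0.0, 1.0)

t = jnp.array([0.0, 1.0, 5.0])
fs = sample_fraction(t, 0.5, 1.0, 1.0, 0.2)
print(fs)  # last entry: raw value 0.5 * e**4 + 0.2 is clipped to 1.0
```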
- heterodyne.core.jax_backend.compute_velocity_integral_matrix(t, v0, beta, v_offset, dt)
JIT-compiled velocity integral matrix (NLSQ meshgrid path).
Computes M[i,j] = ∫_{t_i}^{t_j} v(t') dt' where v(t) = v0 * t^beta + v_offset.
Uses the shared trapezoid_cumsum → create_time_integral_matrix pipeline for O(N) efficiency and O(dt²) accuracy. The velocity integral is signed (not absolute-valued) because it feeds into the phase factor cos(q cos(φ) ∫v dt).
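The cumulative-sum pipeline can be sketched as follows. Assumption: create_time_integral_matrix forms M[i, j] = C[j] - C[i] from the cumulative integral C. The helper names here are illustrative stand-ins, not the module's own functions. The result is antisymmetric and zero on the diagonal, as a signed integral should be.

```python
import jax
import jax.numpy as jnp

def trapezoid_cumsum(f, dt):
    # Cumulative trapezoid rule: C[0] = 0, C[k] = sum_{m<k} (f[m] + f[m+1]) * dt / 2.
    increments = 0.5 * dt * (f[1:] + f[:-1])
    return jnp.concatenate([jnp.zeros((1,), f.dtype), jnp.cumsum(increments)])

@jax.jit
def velocity_integral_matrix(t, v0, beta, v_offset, dt):
    v = v0 * jnp.power(t, beta) + v_offset
    C = trapezoid_cumsum(v, dt)        # O(N) cumulative integral
    return C[None, :] - C[:, None]     # M[i, j] = C[j] - C[i], signed

t = jnp.linspace(0.0, 1.0, 11)
M_const = velocity_integral_matrix(t, 0.0, 1.0, 2.0, 0.1)  # v(t) = 2
M_lin = velocity_integral_matrix(t, 1.0, 1.0, 0.0, 0.1)    # v(t) = t
print(M_const[0, -1], M_lin[0, -1])  # ∫_0^1 2 dt and ∫_0^1 t dt
```

The trapezoid rule is exact for linear integrands, so the v(t) = t case recovers 0.5 up to float rounding.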
- heterodyne.core.jax_backend.compute_transport_integral_matrix(t, D0, alpha, offset, dt)
JIT-compiled transport integral matrix (NLSQ meshgrid path).
Computes M[i,j] = |∫_{t_i}^{t_j} J_rate(t') dt'| where J_rate(t) = D0 * t^alpha + offset.
Uses the shared compute_transport_rate → trapezoid_cumsum → create_time_integral_matrix → smooth_abs pipeline.
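A sketch of the same pipeline for the transport rate, ending with a smoothed absolute value and the module-level half_tr formula. Assumption: smooth_abs is modeled here as sqrt(x² + eps). The real helper's form and eps are not documented in this section, and `transport_integral_matrix` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def smooth_abs(x, eps=1e-12):
    # Hypothetical smooth stand-in for |x|; the module's smooth_abs may differ.
    return jnp.sqrt(x * x + eps)

@jax.jit
def transport_integral_matrix(t, D0, alpha, offset, dt):
    rate = D0 * jnp.power(t, alpha) + offset                # J_rate(t)
    steps = 0.5 * dt * (rate[1:] + rate[:-1])               # trapezoid increments
    C = jnp.concatenate([jnp.zeros((1,), rate.dtype), jnp.cumsum(steps)])
    return smooth_abs(C[None, :] - C[:, None])              # ≈ |∫_{t_i}^{t_j} J_rate dt'|

t = jnp.linspace(0.0, 1.0, 11)
M = transport_integral_matrix(t, 0.0, 1.0, 3.0, 0.1)       # constant rate J_rate = 3
q = 0.5
half_tr = jnp.exp(-0.5 * q**2 * M)                          # half_tr[i,j] from the module intro
```

Unlike the velocity matrix, this one is symmetric, and half_tr is 1 on the diagonal where the integration interval is empty.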
- heterodyne.core.jax_backend.compute_c2_heterodyne(params, t, q, dt, phi_angle, contrast=1.0, offset=1.0)
JIT-compiled two-time heterodyne correlation (meshgrid path).
Thin shim around heterodyne.core.physics_kernel.compute_c2_unified() with eval_strategy="meshgrid". Codex/Gemini G1: the physics math lives in the unified kernel; this function preserves the legacy import path for all NLSQ call sites (compute_residuals, compute_jacobian, multi-angle stratified fits, etc.) without behavioural change.
- Parameters:
- Return type:
- Returns:
Correlation matrix c2, shape (N, N).
- heterodyne.core.jax_backend.compute_c2_heterodyne_multiphi(params, t, q, dt, phi_unique, contrast_arr, offset_arr)
Joint multi-phi c2 evaluator (vmap reference path; not used in the CMC hot path).
Vmap wrapper over compute_c2_heterodyne() that evaluates the heterodyne two-component model for every angle in phi_unique with the matching per-angle contrast_arr/offset_arr. Returns a stacked tensor that can be gathered by (phi_indices, i1, i2) to recover c2 at each pooled (t1, t2, phi) tuple.
Note
The joint multi-phi CMC likelihood (Phase 4) now calls compute_c2_heterodyne_pooled() directly, which computes c2 at the pooled points without ever materializing the (n_phi, N, N) stack this function returns. This helper is retained as a vmap reference path for parity tests (tests/regression/test_cmc_pooled_kernel_parity.py) and ad-hoc diagnostics.
- Parameters:
params (Array) – Parameter array of shape (14,) (same canonical order as compute_c2_heterodyne).
t (Array) – Time grid array, shape (N,).
phi_unique (Array) – Unique phi angles, shape (n_phi,).
contrast_arr (Array) – Per-angle contrast values, shape (n_phi,).
offset_arr (Array) – Per-angle offset values, shape (n_phi,).
- Return type:
- Returns:
Stacked correlation, shape (n_phi, N, N). c2[i, j, k] is the model at phi=phi_unique[i], t1=t[j], t2=t[k] with contrast=contrast_arr[i], offset=offset_arr[i].
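The vmap-then-gather pattern described above can be sketched with `jax.vmap`. Here `in_axes` maps over the angle axis while sharing the time grid. `toy_c2_matrix` is a made-up stand-in for compute_c2_heterodyne, not the real model.

```python
import jax
import jax.numpy as jnp

def toy_c2_matrix(t, phi, contrast, offset):
    # Stand-in for compute_c2_heterodyne (NOT the real physics): one (N, N) matrix per angle.
    T1, T2 = jnp.meshgrid(t, t, indexing="ij")
    return offset + contrast * jnp.exp(-jnp.abs(T1 - T2)) * jnp.cos(phi) ** 2

t = jnp.linspace(0.0, 1.0, 6)
phi_unique = jnp.array([0.0, 0.5, 1.0])
contrast_arr = jnp.array([0.2, 0.3, 0.4])
offset_arr = jnp.array([1.0, 1.1, 1.2])

# vmap over the angle axis only; t is shared (in_axes=None).
multiphi = jax.vmap(toy_c2_matrix, in_axes=(None, 0, 0, 0))
stack = multiphi(t, phi_unique, contrast_arr, offset_arr)   # (n_phi, N, N)

# Gather pooled points by (phi_indices, i1, i2).
phi_indices = jnp.array([0, 2, 1])
i1 = jnp.array([0, 3, 5])
i2 = jnp.array([1, 3, 2])
pooled = stack[phi_indices, i1, i2]                         # (n_total,)
```

The full stack is allocated before the gather. That allocation is the cost the pooled evaluator avoids.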
- heterodyne.core.jax_backend.compute_c2_heterodyne_pooled(params, t, q, dt, idx1, idx2, phi_indices, phi_unique, contrast_arr, offset_arr)
Pooled-data c2 evaluator for joint multi-phi CMC (homodyne parity).
Thin shim around heterodyne.core.physics_kernel.compute_c2_unified() with eval_strategy="pooled". Computes c2 directly at every pooled (phi_indices[k], idx1[k], idx2[k]) tuple and never materializes the (n_phi, N, N) stack that the vmap+gather path required. For N=1000, n_phi=4 this avoids building a (4, 1000, 1000) float64 stack (32 MB on its own, plus the N × N intermediates behind each angle) at every NUTS leapfrog step.
Phase 4 of the joint multi-phi CMC refactor: replaces the per-step materialise-then-gather pattern in heterodyne.optimization.cmc.model._heterodyne_pooled_likelihood().
- Parameters:
params (Array) – Parameter array of shape (14,) (same canonical order as compute_c2_heterodyne()).
t (Array) – Time grid array, shape (N,).
idx1 (Array) – Per-pooled-point index into t for the first time coordinate, shape (n_total,). Typically built via np.searchsorted(t, t1_flat).
idx2 (Array) – Per-pooled-point index into t for the second time coordinate, shape (n_total,).
phi_indices (Array) – Per-pooled-point index into phi_unique, shape (n_total,).
phi_unique (Array) – Sorted unique phi angles, shape (n_phi,).
contrast_arr (Array) – Per-angle contrast values, shape (n_phi,).
offset_arr (Array) – Per-angle offset values, shape (n_phi,).
- Return type:
- Returns:
Pooled correlation values, shape (n_total,). c2[k] is the model at phi=phi_unique[phi_indices[k]], t1=t[idx1[k]], t2=t[idx2[k]] with the matching per-angle scaling.
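The index arrays can be built with `np.searchsorted`, as the parameter notes suggest, and then used to evaluate each pooled point with its own angle's scaling. No (n_phi, N, N) stack is formed. This is a sketch with made-up observation values and a toy kernel in place of the heterodyne model. Note that `np.searchsorted` (default `side="left"`) recovers exact indices only when the flattened times coincide with grid values, which is why the example draws them from `t` directly.

```python
import numpy as np
import jax.numpy as jnp

t = np.linspace(0.0, 1.0, 6)
phi_unique = np.array([0.0, 0.5, 1.0])      # sorted, as the docstring requires

# Hypothetical flattened observations, taken from the grids so lookups are exact.
t1_flat = t[[0, 3, 5, 1]]
t2_flat = t[[1, 3, 2, 0]]
phi_flat = phi_unique[[0, 2, 1, 1]]

idx1 = np.searchsorted(t, t1_flat)
idx2 = np.searchsorted(t, t2_flat)
phi_indices = np.searchsorted(phi_unique, phi_flat)

# Per-point evaluation with per-angle scaling (toy kernel, NOT the heterodyne model).
contrast_arr = jnp.array([0.2, 0.3, 0.4])
offset_arr = jnp.array([1.0, 1.1, 1.2])
tj = jnp.asarray(t)
phi = jnp.asarray(phi_unique)[phi_indices]
c2 = offset_arr[phi_indices] + contrast_arr[phi_indices] * (
    jnp.exp(-jnp.abs(tj[idx1] - tj[idx2])) * jnp.cos(phi) ** 2
)
```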
- heterodyne.core.jax_backend.compute_residuals(params, t, q, dt, phi_angle, c2_data, weights=None, contrast=1.0, offset=1.0)
Compute weighted residuals between model and data.
- Parameters:
params (Array) – Parameter array, shape (14,)
t (Array) – Time array
q (float) – Scattering wavevector
dt (float) – Time step
phi_angle (float) – Detector phi angle
c2_data (Array) – Experimental correlation data
contrast (float) – Speckle contrast (beta), default 1.0
offset (float) – Baseline offset, default 1.0
- Return type:
- Returns:
Flattened residual array
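The weighting convention for residuals is not stated above. A common choice, assumed here and not confirmed by the docstring, is to scale by sqrt(weights). With that convention the sum of squared residuals equals the weighted chi-squared used by compute_chi_squared. `weighted_residuals` is a hypothetical helper that takes a precomputed model array instead of the real parameter vector.

```python
import jax.numpy as jnp

def weighted_residuals(c2_model, c2_data, weights=None):
    # Assumed convention: scale by sqrt(weights) so that
    # sum(residuals**2) == sum((c2_model - c2_data)**2 * weights).
    r = c2_model - c2_data
    if weights is not None:
        r = r * jnp.sqrt(weights)
    return r.ravel()  # flattened, matching "Flattened residual array"

c2_data = jnp.ones((3, 3))
c2_model = c2_data + 0.5
r_plain = weighted_residuals(c2_model, c2_data)                     # all 0.5
r_w = weighted_residuals(c2_model, c2_data, jnp.full((3, 3), 4.0))  # all 1.0
```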
- heterodyne.core.jax_backend.compute_chi_squared(params, t, q, dt, phi_angle, c2_data, weights, contrast, offset)
JIT-compiled chi-squared computation.
chi² = sum((c2_model - c2_data)² × weights)
- Parameters:
params (Array) – Parameter array, shape (14,)
t (Array) – Time array
q (float) – Scattering wavevector
dt (float) – Time step
phi_angle (float) – Detector phi angle
c2_data (Array) – Experimental correlation data
weights (Array) – Weights (1/uncertainty²)
contrast (float) – Speckle contrast
offset (float) – Baseline offset
- Return type:
- Returns:
Chi-squared scalar
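A minimal sketch of the chi-squared formula with weights = 1/σ². For brevity it takes a precomputed model array rather than the 14-parameter vector (a simplification). `chi_squared` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def chi_squared(c2_model, c2_data, weights):
    # chi² = sum((c2_model - c2_data)² × weights), with weights = 1 / sigma².
    return jnp.sum((c2_model - c2_data) ** 2 * weights)

sigma = 0.5
c2_data = jnp.ones((4, 4))
c2_model = c2_data + 0.1
chi2 = chi_squared(c2_model, c2_data, jnp.full((4, 4), 1.0 / sigma**2))
print(chi2)  # 16 points, each contributing (0.1 / 0.5)**2 = 0.04
```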
- heterodyne.core.jax_backend.batch_chi_squared(params_batch, t, q, dt, phi_angle, c2_data, weights, contrast=1.0, offset=1.0, chunk_size=None)
Vectorized chi-squared over a batch of parameter sets.
Uses jax.vmap for efficient parallel evaluation. For large batches or large time grids, chunk_size limits simultaneous N×N allocations to prevent XLA memory exhaustion (each vmap'd evaluation allocates multiple N×N intermediate matrices).
- Parameters:
params_batch (Array) – Parameter sets, shape (n_sets, 14).
t (Array) – Time array, shape (N,).
q (float) – Scattering wavevector.
dt (float) – Time step.
phi_angle (float) – Detector phi angle.
c2_data (Array) – Experimental data.
weights (Array) – Weights.
contrast (float) – Speckle contrast.
offset (float) – Baseline offset.
chunk_size (int | None) – Max batch elements to vmap simultaneously. None (default) auto-selects based on time-grid size: max(1, 200 // (N // 100)) to keep peak memory under ~1.6 GB.
- Return type:
- Returns:
Chi-squared values, shape (n_sets,).
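Chunked vmap can be sketched as a Python loop over slices of the batch. Each slice goes through a jitted `jax.vmap`, and the results are concatenated. This sketch substitutes a toy objective for the real chi-squared, and `batch_chi2_chunked` and `auto_chunk_size` are hypothetical names. The auto-selection formula as quoted would raise a division-by-zero error for N < 100, so the sketch adds a guard. Whether the real implementation handles that case is not stated here.

```python
import jax
import jax.numpy as jnp

def auto_chunk_size(n_times):
    # Heuristic quoted in the docstring, plus an assumed guard for N < 100.
    return max(1, 200 // max(1, n_times // 100))

def batch_chi2_chunked(chi2_fn, params_batch, chunk_size):
    """Evaluate chi2_fn over rows of params_batch, vmapping at most chunk_size rows at once."""
    batched = jax.jit(jax.vmap(chi2_fn))
    out = [
        batched(params_batch[s:s + chunk_size])
        for s in range(0, params_batch.shape[0], chunk_size)
    ]
    return jnp.concatenate(out)

# Toy objective standing in for chi-squared: chi2(params) = sum(params**2).
def toy_chi2(p):
    return jnp.sum(p ** 2)

params_batch = jnp.arange(10 * 14, dtype=jnp.float32).reshape(10, 14)
vals = batch_chi2_chunked(toy_chi2, params_batch, chunk_size=4)
```

A trailing partial chunk (here, 2 rows) has a different shape and triggers one extra compilation. Padding the batch to a multiple of chunk_size would avoid that at the cost of wasted work.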
- heterodyne.core.jax_backend.compute_multi_angle_residuals(params, t, q, dt, phi_angles, c2_data_batch, weights_batch, contrasts, offsets)
JIT-compiled residuals for multiple phi angles simultaneously.
- Parameters:
params (Array) – Parameter array, shape (14,)
t (Array) – Time array, shape (N,)
q (float) – Scattering wavevector
dt (float) – Time step
phi_angles (Array) – Phi angles, shape (n_phi,)
c2_data_batch (Array) – Experimental data, shape (n_phi, N, N)
weights_batch (Array) – Weights, shape (n_phi, N, N)
contrasts (Array) – Per-angle contrasts, shape (n_phi,)
offsets (Array) – Per-angle offsets, shape (n_phi,)
- Return type:
- Returns:
Stacked flattened residuals, shape (n_phi × (N-1) × (N-2),)
- heterodyne.core.jax_backend.compute_chi_squared_grad(params, t, q, dt, phi_angle, c2_data, weights, contrast, offset)
Gradient of compute_chi_squared with respect to params (positional argument 0). Takes the same arguments as compute_chi_squared and returns the gradient, which has the same shape as params, (14,).
- Return type:
- heterodyne.core.jax_backend.compute_chi_squared_hessian(params, t, q, dt, phi_angle, c2_data, weights, contrast, offset)
Hessian of compute_chi_squared with respect to params (positional argument 0), computed as the Jacobian of its gradient. Takes the same arguments as compute_chi_squared and returns a (14, 14) matrix of second derivatives.
- Return type:
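Gradient and Hessian wrappers of this kind can be built with `jax.grad` and `jax.hessian` using `argnums=0`. The results have shapes params.shape and params.shape × 2. This sketch uses a toy two-parameter linear-fit objective in place of the heterodyne chi-squared.

```python
import jax
import jax.numpy as jnp

def chi2(params, x, y, weights):
    # Toy objective with the same calling pattern: params first, data after.
    model = params[0] + params[1] * x
    return jnp.sum((model - y) ** 2 * weights)

chi2_grad = jax.jit(jax.grad(chi2, argnums=0))     # shape == params.shape
chi2_hess = jax.jit(jax.hessian(chi2, argnums=0))  # shape == params.shape * 2

x = jnp.linspace(0.0, 1.0, 5)
y = 1.0 + 2.0 * x
w = jnp.ones_like(x)
p_true = jnp.array([1.0, 2.0])
g = chi2_grad(p_true, x, y, w)   # zero at the exact solution
H = chi2_hess(p_true, x, y, w)   # 2 * [[sum w, sum w x], [sum w x, sum w x²]]
```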
- heterodyne.core.jax_backend.compute_residuals_jacobian(params, t, q, dt, phi_angle, c2_data, weights=None, contrast=1.0, offset=1.0)
Compute Jacobian of residuals with respect to parameters.
- Parameters:
params (Array) – Parameter array, shape (14,)
t (Array) – Time array
q (float) – Scattering wavevector
dt (float) – Time step
phi_angle (float) – Detector phi angle
c2_data (Array) – Experimental correlation data
contrast (float) – Speckle contrast (beta), default 1.0
offset (float) – Baseline offset, default 1.0
- Return type:
- Returns:
Jacobian matrix
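A residual Jacobian of shape (n_residuals, n_params) can be obtained with `jax.jacfwd`. This sketch uses a toy exponential-decay residual rather than the heterodyne model. The `jax.jacfwd` choice is an assumption, since the module's actual differentiation mode is not stated. Forward mode is usually the cheaper choice when the parameters (14 here) are far fewer than the residuals, while `jax.jacrev` suits the opposite case.

```python
import jax
import jax.numpy as jnp

def residuals(params, x, y):
    # Toy residual vector r(params) = A * exp(-k * x) - y, flattened.
    return (params[0] * jnp.exp(-params[1] * x) - y).ravel()

jac_fn = jax.jit(jax.jacfwd(residuals, argnums=0))

x = jnp.linspace(0.0, 2.0, 7)
y = jnp.zeros_like(x)
p = jnp.array([2.0, 0.5])
J = jac_fn(p, x, y)  # shape (n_residuals, n_params) == (7, 2)
# Analytic columns: dr/dA = exp(-k x), dr/dk = -A x exp(-k x).
```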