JAX Backend

CPU-accelerated correlation computation via the NLSQ meshgrid path. Functions in this module build full N x N integral matrices using cumulative trapezoidal sums and are designed for JIT compilation.

JAX-accelerated computational backend for heterodyne correlation.

This module provides JIT-compiled functions for computing the two-component heterodyne correlation function using the integral formulation (PNAS Eq. S-95). All functions are designed to be stateless and compatible with JAX transformations (jit, vmap, grad).

The correlation is computed as c2 = offset + contrast × [ref + sample + cross] / , where transport terms use the integral of the rate J(t): half_tr[i,j] = exp(-½q² × |∫ J_rate(t') dt'|).

The 14 model parameters in canonical order are D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample, D_offset_sample, v0, beta, v_offset, f0, f1, f2, f3, phi0.
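Here is a minimal NumPy sketch of the canonical layout. It helps when building or inspecting the (14,) params array by hand. PARAM_NAMES, pack_params and unpack_params are hypothetical helpers for illustration and are not part of heterodyne.core.jax_backend.

```python
import numpy as np

# Canonical order of the 14 model parameters (hypothetical constant mirroring the list above).
PARAM_NAMES = (
    "D0_ref", "alpha_ref", "D_offset_ref",
    "D0_sample", "alpha_sample", "D_offset_sample",
    "v0", "beta", "v_offset",
    "f0", "f1", "f2", "f3",
    "phi0",
)


def pack_params(values):
    """Build a (14,) float64 array from a {name: value} mapping (hypothetical helper)."""
    missing = [name for name in PARAM_NAMES if name not in values]
    if missing:
        raise KeyError(f"missing parameters: {missing}")
    return np.array([values[name] for name in PARAM_NAMES], dtype=np.float64)


def unpack_params(params):
    """Inverse of pack_params: map each canonical name to its value."""
    params = np.asarray(params, dtype=np.float64)
    if params.shape != (len(PARAM_NAMES),):
        raise ValueError(f"expected shape ({len(PARAM_NAMES)},), got {params.shape}")
    return {name: float(value) for name, value in zip(PARAM_NAMES, params)}
```

In this layout the velocity block (v0, beta, v_offset) occupies indices 6 to 8, and phi0 is the last entry at index 13.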

Three evaluation strategies are exposed (all dispatch through heterodyne.core.physics_kernel.compute_c2_unified()):

  • compute_c2_heterodyne() — full (N, N) meshgrid output for NLSQ.

  • compute_c2_elementwise (re-exported from heterodyne.core.physics_cmc) — sharded per-pair (n_pairs,) output for CMC.

  • compute_c2_heterodyne_pooled() — per-pooled-point (n_total,) output for joint multi-phi CMC (Phase 4; replaces the older vmap+gather pattern in compute_c2_heterodyne_multiphi()).

heterodyne.core.jax_backend.compute_transport_jit(t, D0, alpha, offset, n_times)

JIT-compiled pointwise transport coefficient computation.

J(t) = D0 * t^alpha + offset

Deprecated: pointwise approximation, not used in production correlation. Production code uses compute_transport_integral_matrix for the integral formulation (PNAS Eq. S-95). Retained for test compatibility and 1D visualization helpers.

Parameters:
  • t (Array) – Time array

  • D0 (float) – Transport prefactor

  • alpha (float) – Transport exponent

  • offset (float) – Constant offset

  • n_times (int) – Number of time points (static for JIT)

Return type:

Array

Returns:

Transport coefficient array

heterodyne.core.jax_backend.compute_g1_transport(J, q)

JIT-compiled pointwise g1 correlation from transport coefficient.

g1(t) = exp(-q² * J(t))

Deprecated: pointwise approximation, not used in production correlation. Production code uses the integral formulation via compute_transport_integral_matrix. Retained for test compatibility and 1D visualization helpers.

Parameters:
  • J (Array) – Transport coefficient array

  • q (float) – Scattering wavevector

Return type:

Array

Returns:

g1 correlation array
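Each of the two deprecated pointwise helpers comes down to one line of array arithmetic. The NumPy sketch below uses the hypothetical names transport_pointwise and g1_pointwise. It evaluates J(t) = D0 * t^alpha + offset and g1(t) = exp(-q² * J(t)) exactly as written above. It is intended for 1D plots and quick checks and does not replace the integral formulation.

```python
import numpy as np


def transport_pointwise(t, D0, alpha, offset):
    """J(t) = D0 * t**alpha + offset, evaluated elementwise."""
    t = np.asarray(t, dtype=np.float64)
    return D0 * np.power(t, alpha) + offset


def g1_pointwise(J, q):
    """g1(t) = exp(-q**2 * J(t)), the deprecated pointwise approximation."""
    return np.exp(-(q ** 2) * np.asarray(J, dtype=np.float64))


t = np.array([1.0, 2.0, 4.0])
J = transport_pointwise(t, D0=2.0, alpha=1.0, offset=0.5)  # [2.5, 4.5, 8.5]
g1 = g1_pointwise(J, q=0.1)                                  # decays as J grows
```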

heterodyne.core.jax_backend.compute_fraction_jit(t, f0, f1, f2, f3)

JIT-compiled sample fraction computation.

f_s(t) = f0 * exp(f1 * (t - f2)) + f3, clipped to [0, 1]

Parameters:
  • t (Array) – Time array

  • f0 (float) – Amplitude

  • f1 (float) – Exponential rate

  • f2 (float) – Time shift

  • f3 (float) – Baseline

Return type:

Array

Returns:

Fraction array in [0, 1]
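Here is a NumPy sketch of the sample-fraction formula, including the final clip to [0, 1]. The name sample_fraction is hypothetical. The clip takes effect whenever f3 or the exponential term pushes the raw value outside the physical range.

```python
import numpy as np


def sample_fraction(t, f0, f1, f2, f3):
    """f_s(t) = f0 * exp(f1 * (t - f2)) + f3, clipped to [0, 1]."""
    t = np.asarray(t, dtype=np.float64)
    raw = f0 * np.exp(f1 * (t - f2)) + f3
    return np.clip(raw, 0.0, 1.0)


# The raw values are 0.5 at t = 0 and about -0.5 at t = 100; the second is clipped to 0.
f = sample_fraction(np.array([0.0, 100.0]), f0=1.0, f1=-1.0, f2=0.0, f3=-0.5)
```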

heterodyne.core.jax_backend.compute_velocity_integral_matrix(t, v0, beta, v_offset, dt)

JIT-compiled velocity integral matrix (NLSQ meshgrid path).

Computes M[i,j] = ∫_{t_i}^{t_j} v(t') dt', where v(t) = v0 * t^beta + v_offset.

Uses the shared trapezoid_cumsum → create_time_integral_matrix pipeline: one O(N) cumulative trapezoid pass, then pairwise differences to fill the N×N matrix, with O(dt²) accuracy. The velocity integral is signed (not absolute-valued) because it feeds into the phase factor cos(q cos(φ) ∫v dt).

Parameters:
  • t (Array) – Time array, shape (N,)

  • v0 (float) – Velocity prefactor

  • beta (float) – Velocity exponent

  • v_offset (float) – Velocity offset

  • dt (float) – Time step

Return type:

Array

Returns:

Signed integral matrix, shape (N, N)
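The velocity integral matrix can be sketched in NumPy as a single cumulative trapezoid pass over v(t), followed by a pairwise difference M[i, j] = C[j] - C[i]. cumtrapz_uniform and velocity_integral_matrix_sketch are hypothetical stand-ins for the shared trapezoid_cumsum → create_time_integral_matrix helpers, whose exact signatures are not documented here. The result is antisymmetric (M[j, i] = -M[i, j]), and the phase factor relies on exactly this sign information.

```python
import numpy as np


def cumtrapz_uniform(y, dt):
    """Cumulative trapezoid integral on a uniform grid: C[0] = 0, C[k] ≈ ∫_{t_0}^{t_k} y dt."""
    y = np.asarray(y, dtype=np.float64)
    increments = 0.5 * dt * (y[1:] + y[:-1])
    return np.concatenate(([0.0], np.cumsum(increments)))


def velocity_integral_matrix_sketch(t, v0, beta, v_offset, dt):
    """Signed M[i, j] ≈ ∫_{t_i}^{t_j} v(t') dt' with v(t) = v0 * t**beta + v_offset."""
    v = v0 * np.power(np.asarray(t, dtype=np.float64), beta) + v_offset
    C = cumtrapz_uniform(v, dt)       # one O(N) pass
    return C[None, :] - C[:, None]    # M[i, j] = C[j] - C[i], shape (N, N)
```

For v(t) = t the trapezoid rule is exact, so M[i, j] equals (t_j² - t_i²)/2 up to rounding. For curved v(t) the error shrinks as dt².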

heterodyne.core.jax_backend.compute_transport_integral_matrix(t, D0, alpha, offset, dt)

JIT-compiled transport integral matrix (NLSQ meshgrid path).

Computes M[i,j] = |∫_{t_i}^{t_j} J_rate(t') dt'| where J_rate(t) = D0 * t^alpha + offset

Uses the shared compute_transport_rate → trapezoid_cumsum → create_time_integral_matrix → smooth_abs pipeline.

Parameters:
  • t (Array) – Time array, shape (N,)

  • D0 (float) – Transport prefactor

  • alpha (float) – Transport exponent

  • offset (float) – Transport rate offset

  • dt (float) – Time step

Return type:

Array

Returns:

Transport integral matrix, shape (N, N)
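The transport path differs from the velocity path only in its final step. There the signed integral is passed through smooth_abs before it enters half_tr = exp(-½ q² M). This NumPy sketch assumes the smooth absolute value is sqrt(x² + eps). The real smooth_abs may use a different form or epsilon, and every function name below is hypothetical.

```python
import numpy as np


def cumtrapz_uniform(y, dt):
    """Cumulative trapezoid integral on a uniform grid, starting at 0."""
    y = np.asarray(y, dtype=np.float64)
    return np.concatenate(([0.0], np.cumsum(0.5 * dt * (y[1:] + y[:-1]))))


def smooth_abs_sketch(x, eps=1e-12):
    """Assumed differentiable stand-in for |x|: sqrt(x**2 + eps)."""
    return np.sqrt(x * x + eps)


def transport_integral_matrix_sketch(t, D0, alpha, offset, dt):
    """M[i, j] ≈ |∫_{t_i}^{t_j} J_rate(t') dt'| with J_rate(t) = D0 * t**alpha + offset."""
    rate = D0 * np.power(np.asarray(t, dtype=np.float64), alpha) + offset
    C = cumtrapz_uniform(rate, dt)
    return smooth_abs_sketch(C[None, :] - C[:, None])


def half_transport(M, q):
    """half_tr[i, j] = exp(-0.5 * q**2 * M[i, j]), per the module overview."""
    return np.exp(-0.5 * q ** 2 * M)
```

A constant rate J_rate = 2 gives M[i, j] ≈ 2|t_j - t_i|. The diagonal comes out as sqrt(eps) rather than exactly 0, and this is what keeps the function differentiable at zero.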

heterodyne.core.jax_backend.compute_c2_heterodyne(params, t, q, dt, phi_angle, contrast=1.0, offset=1.0)

JIT-compiled two-time heterodyne correlation (meshgrid path).

Thin shim around heterodyne.core.physics_kernel.compute_c2_unified() with eval_strategy="meshgrid". The physics math lives in the unified kernel; this function preserves the legacy import path for all NLSQ call sites (compute_residuals, compute_jacobian, multi-angle stratified fits, etc.) without behavioural change.

Parameters:
  • params (Array) – Parameter array of shape (14,) in canonical order: [D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample, D_offset_sample, v0, beta, v_offset, f0, f1, f2, f3, phi0].

  • t (Array) – Time array, shape (N,)

  • q (float | Array) – Scattering wavevector magnitude

  • dt (float | Array) – Time step

  • phi_angle (float | Array) – Detector phi angle (degrees)

  • contrast (float | Array) – Speckle contrast (usually written β; not the velocity exponent beta in params), default 1.0

  • offset (float | Array) – Baseline offset, default 1.0

Return type:

Array

Returns:

Correlation matrix c2, shape (N, N).

heterodyne.core.jax_backend.compute_c2_heterodyne_multiphi(params, t, q, dt, phi_unique, contrast_arr, offset_arr)

Joint multi-phi c2 evaluator (vmap reference path — not used in CMC hot path).

Vmap wrapper over compute_c2_heterodyne() that evaluates the heterodyne two-component model for every angle in phi_unique with the matching per-angle contrast_arr / offset_arr. Returns a stacked tensor that can be gathered by (phi_indices, i1, i2) to recover c2 at each pooled (t1, t2, phi) tuple.

Note

The joint multi-phi CMC likelihood (Phase 4) now calls compute_c2_heterodyne_pooled() directly, which computes c2 at the pooled points without ever materializing the (n_phi, N, N) stack this function returns. This helper is retained as a vmap reference path for parity tests (tests/regression/test_cmc_pooled_kernel_parity.py) and ad-hoc diagnostics.

Parameters:
  • params (Array) – Parameter array of shape (14,) (same canonical order as compute_c2_heterodyne).

  • t (Array) – Time grid array, shape (N,).

  • q (float | Array) – Scattering wavevector magnitude.

  • dt (float | Array) – Time step.

  • phi_unique (Array) – Unique phi angles, shape (n_phi,).

  • contrast_arr (Array) – Per-angle contrast values, shape (n_phi,).

  • offset_arr (Array) – Per-angle offset values, shape (n_phi,).

Return type:

Array

Returns:

Stacked correlation, shape (n_phi, N, N). c2[i, j, k] is the model at phi=phi_unique[i], t1=t[j], t2=t[k] with contrast=contrast_arr[i], offset=offset_arr[i].
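For reference, gathering from this function's output is plain advanced indexing on the stacked tensor. The NumPy sketch below uses a toy stack in place of real model output. gather_pooled and stack_nbytes are hypothetical helpers. stack_nbytes gives the size of the stack that compute_c2_heterodyne_pooled() avoids building.

```python
import numpy as np


def gather_pooled(stack, phi_indices, idx1, idx2):
    """Select c2 at pooled (phi, t1, t2) tuples from an (n_phi, N, N) stack."""
    return stack[phi_indices, idx1, idx2]


def stack_nbytes(n_phi, n_times, itemsize=np.dtype(np.float64).itemsize):
    """Bytes needed to materialize an (n_phi, N, N) stack."""
    return n_phi * n_times * n_times * itemsize


stack = np.arange(18.0).reshape(2, 3, 3)                        # toy stand-in for model output
values = gather_pooled(stack, [0, 1, 1], [0, 2, 1], [1, 0, 2])  # [1.0, 15.0, 14.0]
```

With n_phi = 4 and N = 1000, the float64 stack alone is 4 × 1000² × 8 B = 32 MB.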

heterodyne.core.jax_backend.compute_c2_heterodyne_pooled(params, t, q, dt, idx1, idx2, phi_indices, phi_unique, contrast_arr, offset_arr)

Pooled-data c2 evaluator for joint multi-phi CMC (homodyne parity).

Thin shim around heterodyne.core.physics_kernel.compute_c2_unified() with eval_strategy="pooled". Computes c2 directly at every pooled (phi_indices[k], idx1[k], idx2[k]) tuple and never materializes the (n_phi, N, N) stack that the vmap+gather path required. For N=1000 and n_phi=4, this eliminates ~256 MB of float64 intermediates at every NUTS leapfrog step. The stacked output alone is 4 × 1000² × 8 B = 32 MB, so the figure also counts the per-angle N×N temporaries built to produce it.

Phase 4 of the joint multi-phi CMC refactor — replaces the per-step materialise-then-gather pattern in heterodyne.optimization.cmc.model._heterodyne_pooled_likelihood().

Parameters:
  • params (Array) – Parameter array of shape (14,) (same canonical order as compute_c2_heterodyne()).

  • t (Array) – Time grid array, shape (N,).

  • q (float | Array) – Scattering wavevector magnitude.

  • dt (float | Array) – Time step.

  • idx1 (Array) – Per-pooled-point index into t for the first time coordinate, shape (n_total,). Typically built via np.searchsorted(t, t1_flat).

  • idx2 (Array) – Per-pooled-point index into t for the second time coordinate, shape (n_total,).

  • phi_indices (Array) – Per-pooled-point index into phi_unique, shape (n_total,).

  • phi_unique (Array) – Sorted unique phi angles, shape (n_phi,).

  • contrast_arr (Array) – Per-angle contrast values, shape (n_phi,).

  • offset_arr (Array) – Per-angle offset values, shape (n_phi,).

Return type:

Array

Returns:

Pooled correlation values, shape (n_total,). c2[k] is the model at phi=phi_unique[phi_indices[k]], t1=t[idx1[k]], t2=t[idx2[k]] with the matching per-angle scaling.
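The index arrays can be built the way the idx1 description suggests. The time coordinates use np.searchsorted, and phi uses np.unique(..., return_inverse=True), which also produces the sorted phi_unique this function expects. Below is a NumPy sketch with the hypothetical helper build_pooled_indices. It assumes every pooled time lies exactly on the grid t and rejects inputs that do not.

```python
import numpy as np


def build_pooled_indices(t, t1_flat, t2_flat, phi_flat):
    """Map pooled (t1, t2, phi) samples to (idx1, idx2, phi_indices, phi_unique)."""
    t = np.asarray(t)
    # Clip so out-of-range samples fail the grid check below instead of raising IndexError.
    idx1 = np.clip(np.searchsorted(t, t1_flat), 0, len(t) - 1)
    idx2 = np.clip(np.searchsorted(t, t2_flat), 0, len(t) - 1)
    if not (np.allclose(t[idx1], t1_flat) and np.allclose(t[idx2], t2_flat)):
        raise ValueError("pooled times must lie on the grid t")
    # np.unique returns sorted values plus, per sample, the index into those values.
    phi_unique, phi_indices = np.unique(phi_flat, return_inverse=True)
    return idx1, idx2, phi_indices, phi_unique
```

A sample that differs from a grid point by floating-point noise can land on the neighbouring index, so snapping pooled times to the grid before this step is a reasonable precaution.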

heterodyne.core.jax_backend.compute_residuals(params, t, q, dt, phi_angle, c2_data, weights=None, contrast=1.0, offset=1.0)

Compute weighted residuals between model and data.

Parameters:
  • params (Array) – Parameter array, shape (14,)

  • t (Array) – Time array

  • q (float) – Scattering wavevector

  • dt (float) – Time step

  • phi_angle (float) – Detector phi angle

  • c2_data (Array) – Experimental correlation data

  • weights (Array | None) – Optional weights (1/uncertainty²)

  • contrast (float) – Speckle contrast (beta), default 1.0

  • offset (float) – Baseline offset, default 1.0

Return type:

Array

Returns:

Flattened residual array

heterodyne.core.jax_backend.compute_chi_squared(params, t, q, dt, phi_angle, c2_data, weights, contrast, offset)

JIT-compiled chi-squared computation.

chi² = sum((c2_model - c2_data)² × weights)

Parameters:
  • params (Array) – Parameter array, shape (14,)

  • t (Array) – Time array

  • q (float) – Scattering wavevector

  • dt (float) – Time step

  • phi_angle (float) – Detector phi angle

  • c2_data (Array) – Experimental correlation data

  • weights (Array) – Weights (1/uncertainty²)

  • contrast (float) – Speckle contrast

  • offset (float) – Baseline offset

Return type:

Array

Returns:

Chi-squared scalar
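The residual and chi-squared definitions agree if residuals are scaled by sqrt(weights). The documentation above does not say how compute_residuals applies its weights, so that scaling is an assumption of this NumPy sketch. The sketch also works on precomputed c2 matrices rather than on params. weighted_residuals and chi_squared are hypothetical names.

```python
import numpy as np


def weighted_residuals(c2_model, c2_data, weights=None):
    """Flattened residuals scaled by sqrt(weights), so that sum(r**2) equals chi-squared."""
    diff = np.asarray(c2_model, dtype=np.float64) - np.asarray(c2_data, dtype=np.float64)
    if weights is not None:
        diff = diff * np.sqrt(weights)
    return diff.ravel()


def chi_squared(c2_model, c2_data, weights):
    """chi² = sum((c2_model - c2_data)² × weights), as documented above."""
    diff = np.asarray(c2_model, dtype=np.float64) - np.asarray(c2_data, dtype=np.float64)
    return float(np.sum(diff ** 2 * weights))
```

With this convention, a least-squares solver that minimizes sum(r²) is minimizing exactly the documented chi².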

heterodyne.core.jax_backend.batch_chi_squared(params_batch, t, q, dt, phi_angle, c2_data, weights, contrast=1.0, offset=1.0, chunk_size=None)

Vectorized chi-squared over a batch of parameter sets.

Uses jax.vmap for efficient parallel evaluation. For large batches or large time grids, chunk_size limits simultaneous N×N allocations to prevent XLA memory exhaustion (each vmap’d evaluation allocates multiple N×N intermediate matrices).

Parameters:
  • params_batch (Array) – Parameter sets, shape (n_sets, 14).

  • t (Array) – Time array, shape (N,).

  • q (float) – Scattering wavevector.

  • dt (float) – Time step.

  • phi_angle (float) – Detector phi angle.

  • c2_data (Array) – Experimental data.

  • weights (Array) – Weights.

  • contrast (float) – Speckle contrast.

  • offset (float) – Baseline offset.

  • chunk_size (int | None) – Max batch elements to vmap simultaneously. None (default) auto-selects based on time-grid size: max(1, 200 // (N // 100)) to keep peak memory under ~1.6 GB.

Return type:

Array

Returns:

Chi-squared values, shape (n_sets,).
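Here is a NumPy sketch of the chunking idea and the documented default chunk size. batch_chi_squared_chunked and auto_chunk_size are hypothetical names. chi2_fn stands for any function that computes chi-squared for a single parameter set. The Python loop takes the place of the per-chunk jax.vmap.

```python
import numpy as np


def auto_chunk_size(n_times):
    # Documented default: max(1, 200 // (N // 100)). The inner max(1, ...) guard is
    # added by this sketch, because N // 100 is 0 (division by zero) when N < 100.
    return max(1, 200 // max(1, n_times // 100))


def batch_chi_squared_chunked(chi2_fn, params_batch, n_times, chunk_size=None):
    """Evaluate chi2_fn on each row of params_batch, at most chunk_size rows at a time."""
    params_batch = np.asarray(params_batch, dtype=np.float64)
    if chunk_size is None:
        chunk_size = auto_chunk_size(n_times)
    results = []
    for start in range(0, params_batch.shape[0], chunk_size):
        chunk = params_batch[start:start + chunk_size]
        # A JAX version would apply jax.vmap(chi2_fn) to the whole chunk here;
        # this loop evaluates the same rows one at a time.
        results.append(np.array([chi2_fn(p) for p in chunk], dtype=np.float64))
    return np.concatenate(results) if results else np.empty(0)
```

The default gives 20 parameter sets per chunk for N = 1000 and 10 for N = 2000. Each evaluation's N×N temporaries grow like N², but the chunk size only shrinks like 1/N. Peak memory under this rule therefore still grows roughly in proportion to N.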

heterodyne.core.jax_backend.compute_multi_angle_residuals(params, t, q, dt, phi_angles, c2_data_batch, weights_batch, contrasts, offsets)

JIT-compiled residuals for multiple phi angles simultaneously.

Parameters:
  • params (Array) – Parameter array, shape (14,)

  • t (Array) – Time array, shape (N,)

  • q (float) – Scattering wavevector

  • dt (float) – Time step

  • phi_angles (Array) – Phi angles, shape (n_phi,)

  • c2_data_batch (Array) – Experimental data, shape (n_phi, N, N)

  • weights_batch (Array) – Weights, shape (n_phi, N, N)

  • contrasts (Array) – Per-angle contrasts, shape (n_phi,)

  • offsets (Array) – Per-angle offsets, shape (n_phi,)

Return type:

Array

Returns:

Stacked flattened residuals, shape (n_phi × (N-1) × (N-2),)
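The residual length n_phi × (N-1) × (N-2) is smaller than the full n_phi × N², so some entries must be excluded before flattening. This page does not document which ones. One selection that reproduces the count exactly drops the first row, the first column, and the diagonal, as in the hypothetical NumPy mask below. Treat it as a size check only, not as the module's actual mask.

```python
import numpy as np


def residual_mask_hypothetical(n_times):
    """Boolean (N, N) mask without row 0, column 0 and the diagonal: (N-1)*(N-2) True entries."""
    mask = np.ones((n_times, n_times), dtype=bool)
    mask[0, :] = False
    mask[:, 0] = False
    np.fill_diagonal(mask, False)
    return mask


def stacked_residual_length(n_phi, n_times):
    """Length of the flattened multi-angle residual vector under the documented shape."""
    return n_phi * (n_times - 1) * (n_times - 2)
```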

heterodyne.core.jax_backend.compute_chi_squared_grad(params, t, q, dt, phi_angle, c2_data, weights, contrast, offset)

Gradient of compute_chi_squared with respect to params (positional argument 0). Takes the same arguments as compute_chi_squared and returns an array with the same shape as params, (14,).

Return type:

Array

heterodyne.core.jax_backend.compute_chi_squared_hessian(params, t, q, dt, phi_angle, c2_data, weights, contrast, offset)

Hessian of compute_chi_squared with respect to params (positional argument 0), computed as the Jacobian of its gradient. Takes the same arguments as compute_chi_squared and returns a (14, 14) matrix.

Return type:

Array

heterodyne.core.jax_backend.compute_residuals_jacobian(params, t, q, dt, phi_angle, c2_data, weights=None, contrast=1.0, offset=1.0)

Compute Jacobian of residuals with respect to parameters.

Parameters:
  • params (Array) – Parameter array, shape (14,)

  • t (Array) – Time array

  • q (float) – Scattering wavevector

  • dt (float) – Time step

  • phi_angle (float) – Detector phi angle

  • c2_data (Array) – Experimental correlation data

  • weights (Array | None) – Optional weights (1/uncertainty²)

  • contrast (float) – Speckle contrast (beta), default 1.0

  • offset (float) – Baseline offset, default 1.0

Return type:

Array

Returns:

Jacobian matrix
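A central-difference Jacobian is a common independent cross-check for compute_residuals_jacobian. This NumPy sketch uses the hypothetical helper finite_difference_jacobian and returns shape (n_residuals, n_params). That is the layout jax.jacfwd produces for a vector-valued function of a vector; whether compute_residuals_jacobian uses the same layout is an assumption. In practice, residual_fn would be a closure over compute_residuals with every argument except params held fixed.

```python
import numpy as np


def finite_difference_jacobian(residual_fn, params, rel_step=1e-6):
    """Central-difference estimate of J[k, p] = d r_k / d params_p."""
    params = np.asarray(params, dtype=np.float64)
    r0 = np.asarray(residual_fn(params)).ravel()
    jac = np.empty((r0.size, params.size))
    for p in range(params.size):
        # Scale the step with the parameter magnitude, but never below rel_step.
        h = rel_step * max(1.0, abs(params[p]))
        step = np.zeros_like(params)
        step[p] = h
        r_plus = np.asarray(residual_fn(params + step)).ravel()
        r_minus = np.asarray(residual_fn(params - step)).ravel()
        jac[:, p] = (r_plus - r_minus) / (2.0 * h)
    return jac
```

Central differences are exact for linear residuals up to rounding, which makes a linear toy model a convenient sanity check before comparing against the real model.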