ADR-001: JAX CPU-Only Backend
- Status: Accepted
- Date: 2026-05-19
- Deciders: Core team
Context
Heterodyne targets CPU-only JAX (no GPU/TPU support). The JAX backend is used for:
- JIT compilation of the physics kernel (jax_backend.py, physics_cmc.py).
- Automatic differentiation for Jacobians (NLSQ) and NUTS gradients (CMC).
- Vectorized operations via jax.vmap over angles and time points.
The CMC backend explicitly sets:
```python
# In multiprocessing.py worker initialization (must run before jax is imported)
import os

os.environ["JAX_ENABLE_X64"] = "1"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
```
This creates 4 virtual JAX devices per worker process, enabling NumPyro’s parallel
chain execution mode (pmap over 4 devices) without requiring a real GPU.
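The sketch below shows one way to compose these worker settings without discarding XLA flags that a user has already exported. `build_worker_env` is a hypothetical helper, not Heterodyne's code. It works on a plain mapping, so a caller could pass the result as a child process's environment before JAX is imported there.

```python
import os
from collections.abc import Mapping

DEVICE_FLAG = "--xla_force_host_platform_device_count"


def build_worker_env(base: Mapping[str, str], n_devices: int = 4) -> dict[str, str]:
    """Return a copy of `base` with float64 enabled and `n_devices` virtual CPU devices.

    Hypothetical helper: unrelated XLA flags are kept, but any earlier
    device-count flag is replaced so the worker sees exactly `n_devices`.
    """
    env = dict(base)
    env["JAX_ENABLE_X64"] = "1"
    # Keep unrelated XLA flags; drop any stale device-count setting.
    kept = [f for f in env.get("XLA_FLAGS", "").split() if not f.startswith(DEVICE_FLAG)]
    kept.append(f"{DEVICE_FLAG}={n_devices}")
    env["XLA_FLAGS"] = " ".join(kept)
    return env


if __name__ == "__main__":
    worker_env = build_worker_env(os.environ)
    print(worker_env["XLA_FLAGS"])
```

NumPyro also provides `numpyro.set_host_device_count`, which edits XLA_FLAGS in the same way. Like the environment variables above, it only takes effect if it is called before JAX initializes its CPU backend.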
The physics kernel evaluates the two-component correlation model over the 14-parameter
space (D0_ref, alpha_ref, D_offset_ref, D0_sample, alpha_sample, D_offset_sample,
v0, beta, v_offset, f0, f1, f2, f3, phi0) plus 2 per-angle scaling parameters.
The full two-time correlation matrix has N² entries, where N is the number of time
frames, so its memory footprint grows as O(N²): 8N² bytes per matrix at double precision.
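For a sense of scale, the arithmetic below assumes one float64 value per (t1, t2) entry and one matrix per angle. `n_angles` is an illustrative parameter. Any intermediate arrays the kernel may allocate are not counted.

```python
BYTES_PER_FLOAT64 = 8


def two_time_matrix_bytes(n_frames: int, n_angles: int = 1) -> int:
    """Bytes needed to hold n_angles float64 two-time matrices of shape (N, N)."""
    return n_angles * n_frames * n_frames * BYTES_PER_FLOAT64


# 1000 frames -> 8 MB per matrix; 10000 frames -> 800 MB per matrix.
for n in (1000, 5000, 10000):
    print(f"N={n}: {two_time_matrix_bytes(n) / 1e6:,.0f} MB per matrix")
```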
Decision
Use JAX on CPU as the sole numerical backend. Enforce this via:
```python
# heterodyne/__init__.py
import os

os.environ.setdefault("JAX_ENABLE_X64", "1")
os.environ.setdefault("JAX_PLATFORMS", "cpu")

# device/cpu.py
import jax

jax.config.update("jax_platform_name", "cpu")
```
Float64 precision is mandatory. The transport coefficient integrals
involve subtraction of nearly equal large values at small lag times, producing catastrophic cancellation in float32.
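The cancellation effect can be reproduced with NumPy scalars. The values below are illustrative and are not taken from the transport coefficient integrals. A small increment is added to a large value and then the large value is subtracted again, first in float32 and then in float64.

```python
import numpy as np

# Two nearly equal "large" values, a and a + delta (illustrative numbers).
a = 1.0e4
delta = 1.0e-4


def cancelled_difference(dtype: type) -> float:
    """Compute (a + delta) - a entirely in the given floating-point dtype."""
    x = dtype(a) + dtype(delta)  # rounding happens here
    return float(x - dtype(a))  # the subtraction exposes that rounding error


for dtype in (np.float32, np.float64):
    diff = cancelled_difference(dtype)
    rel_err = abs(diff - delta) / delta
    print(f"{dtype.__name__}: difference={diff:.6e}, relative error={rel_err:.1e}")
```

In float32 the spacing between representable numbers near 1e4 is about 1e-3, so the increment is lost entirely and the relative error is 100%. In float64 the same difference keeps roughly eight significant digits.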
Rationale
Target deployment environment
Beamline workstations (APS, ESRF, PETRA III) typically run Linux on multi-core CPUs (16–128 cores) without GPU access. The analysis must run at the beamline during experiments, so installation complexity must be minimal.
Float64 requirement
JAX defaults to float32 on every backend, CPU included; 64-bit types are disabled unless
explicitly enabled. The 14-parameter heterodyne model requires float64 throughout: the
transport coefficient integrals involve subtraction of nearly equal values at short lag
times, producing O(1) relative errors in float32. Setting JAX_ENABLE_X64=1 enables
float64, but JAX reads this variable when it is first imported, so it must be set before
the first JAX import. Both heterodyne/__init__.py and cli/main.py call
os.environ.setdefault("JAX_ENABLE_X64", "1") to ensure this.
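The sketch below makes a late import visible instead of silent. `ensure_x64_before_jax` is a hypothetical helper, not part of Heterodyne. It only checks `sys.modules` to detect whether JAX has already been imported before the variable was set.

```python
import os
import sys
import warnings


def ensure_x64_before_jax() -> bool:
    """Request float64 via the environment; warn if JAX was already imported.

    Hypothetical helper. Returns True if the setting can still take effect
    (JAX not yet imported), False if it has only emitted a warning.
    """
    os.environ.setdefault("JAX_ENABLE_X64", "1")
    if "jax" in sys.modules:
        warnings.warn(
            "jax was imported before JAX_ENABLE_X64 was set; float64 may be disabled",
            RuntimeWarning,
            stacklevel=2,
        )
        return False
    return True
```

A package entry point could call this at the top of the module. On False, the caller decides whether to abort or continue.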
NLSQ library boundary
The primary optimizer (nlsq.curve_fit) calls JAX JIT-compiled residual and Jacobian
functions, but the outer trust-region loop runs on the CPU host. This creates a
JAX-device-to-host round-trip per iteration. For GPU execution, this round-trip would
dominate latency (PCIe bandwidth ~16 GB/s; the Jacobian has O(N² × 16) float64 entries,
with 16 = 14 physics + 2 scaling parameters, i.e. roughly 128·N² bytes). On CPU, JAX
device buffers live in host memory, so the round-trip needs no PCIe transfer.
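The following is a bandwidth-only estimate of that per-iteration copy. It assumes one dense float64 Jacobian with N² rows and 16 columns (14 physics + 2 scaling parameters, single angle) and the ~16 GB/s PCIe figure quoted above. Latency, pinned-memory effects and overlap with compute are ignored, so real timings would differ.

```python
BYTES_PER_FLOAT64 = 8
PCIE_BYTES_PER_S = 16e9  # ~16 GB/s, the figure quoted in the text


def jacobian_transfer_seconds(n_frames: int, n_params: int = 16) -> float:
    """Seconds to copy one dense float64 Jacobian of N^2 rows x n_params columns."""
    n_bytes = n_frames * n_frames * n_params * BYTES_PER_FLOAT64
    return n_bytes / PCIE_BYTES_PER_S


for n in (1000, 5000):
    print(f"N={n}: {jacobian_transfer_seconds(n) * 1e3:.0f} ms per iteration")
```

For N = 1000 this is about 8 ms per trust-region iteration, and the cost grows with N².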
CMC multiprocessing
The CMC backend spawns ⌊N_cores/2⌋ − 1 worker processes, each running NUTS on a
data shard. Each worker sets XLA_FLAGS=--xla_force_host_platform_device_count=4
to create 4 virtual devices, enabling pmap over 4 NUTS chains. This requires CPU-only
JAX; GPU workers would compete for the same physical GPU.
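The sketch below applies the worker-count rule to a few typical core counts. `cmc_layout` is a hypothetical helper. The clamp to at least one worker is an assumption, because the formula itself gives zero or fewer workers below 4 cores. The four virtual devices per worker share that worker's CPU threads; the XLA flag does not add physical cores.

```python
def cmc_layout(n_cores: int, devices_per_worker: int = 4) -> tuple[int, int]:
    """Return (workers, chains) for the floor(N_cores / 2) - 1 worker rule.

    Hypothetical helper; the clamp to >= 1 worker is an assumption, not
    documented Heterodyne behaviour.
    """
    workers = max(1, n_cores // 2 - 1)
    # One NUTS chain per virtual device in each worker.
    return workers, workers * devices_per_worker


for cores in (16, 32, 128):
    workers, chains = cmc_layout(cores)
    print(f"{cores} cores -> {workers} workers, {chains} chains in flight")
```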
Determinism
CPU JAX execution is deterministic given the same seed. GPU JAX is non-deterministic
by default: floating-point addition is non-associative, and CUDA reductions do not
guarantee a fixed summation order between runs. Heterodyne's reproducibility guarantee
(via CMCConfig.seed) requires deterministic execution.
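The dependence on summation order can be shown with plain Python floats, no JAX involved. The toy reduction below sums the same three numbers in two orders and gets two different doubles. A reduction whose order varies from run to run can differ in the same way. `reduce_in_order` is an illustrative helper.

```python
def reduce_in_order(values: list[float], order: list[int]) -> float:
    """Sum `values` sequentially in the given index order (a stand-in for a reduction schedule)."""
    total = 0.0
    for i in order:
        total += values[i]
    return total


values = [0.1, 0.2, 0.3]
forward = reduce_in_order(values, [0, 1, 2])  # (0.1 + 0.2) + 0.3
backward = reduce_in_order(values, [2, 1, 0])  # (0.3 + 0.2) + 0.1

print(forward, backward)  # 0.6000000000000001 0.6
print(forward == backward)  # False
```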
Consequences
Positive:
- Zero-configuration installation at beamlines: uv sync is sufficient.
- Float64 precision throughout: no float32 cancellation errors in the transport coefficient integrals.
- Deterministic results given the same seed.
- Simple testing matrix: no CUDA version matrix.
Negative / Accepted trade-offs:
- No GPU acceleration. The NLSQ Jacobian computation (O(N² × n_params) per iteration) is the primary bottleneck for large datasets.
- CMC wall time scales with N_shards × N_warmup_samples; typical runs take 30–120 minutes for 1000-frame datasets on a 32-core workstation.
- XLA_FLAGS and JAX_ENABLE_X64 must be set before the first JAX import; late configuration is silently ignored.
Alternatives Considered
A. GPU-first JAX
Would enable faster matrix operations for the NLSQ Jacobian computation and NUTS leapfrog steps. Rejected because: GPUs are not universally available at beamlines; 32-bit precision limitation (float64 throughput is low on many GPUs, and the model requires float64); non-deterministic results.
B. NumPy / SciPy only
Would eliminate the JAX dependency, simplifying installation. Rejected because: no automatic differentiation (would require finite-difference Jacobians, ~100× slower); no JIT compilation; cannot leverage the NumPyro ecosystem for NUTS.
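The sketch below shows the extra work a finite-difference Jacobian brings. It is a plain-Python forward-difference scheme. It needs one additional residual evaluation per parameter (17 evaluations for 16 parameters) and has truncation error of order the step size. It only makes that per-parameter cost and step-size choice visible; it does not attempt to reproduce the ~100× figure. The function and the relative step rule are illustrative assumptions.

```python
from collections.abc import Callable, Sequence

# A residual function maps a parameter vector to a list of residuals.
Residual = Callable[[Sequence[float]], list[float]]


def forward_difference_jacobian(
    f: Residual, params: Sequence[float], step: float = 1e-7
) -> tuple[list[list[float]], int]:
    """Approximate J[i][j] = d f_i / d p_j by forward differences.

    Returns the Jacobian (one row per residual) and the number of
    residual evaluations used, which is len(params) + 1.
    """
    base = f(params)
    n_evals = 1
    columns = []
    for j in range(len(params)):
        shifted = list(params)
        h = step * max(1.0, abs(params[j]))  # relative step for large parameters
        shifted[j] += h
        columns.append([(fi - bi) / h for fi, bi in zip(f(shifted), base)])
        n_evals += 1
    # Transpose the per-parameter columns into per-residual rows.
    jac = [list(row) for row in zip(*columns)]
    return jac, n_evals
```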
C. PyTorch backend
PyTorch has mature CPU and GPU support with autograd. Rejected because: the NumPyro
probabilistic programming library is JAX-native, and rewriting NUTS from scratch would
be a significant engineering effort; JAX’s jit/vmap/pmap composability is
a better fit for the parallelism patterns in CMC.
D. Mixed CPU/GPU
Support both, dispatch at runtime. Rejected for the current version because: doubles the testing matrix; GPU code path is untested and risky; adds complexity without clear benefit for the target deployment environment. May be reconsidered in a future version.
See also
- ADR-006: No GPU Acceleration on Consumer Hardware (quantitative GPU feasibility assessment, RTX 4090)
- Architecture Overview (overall system design)
- ADR-002: NLSQ / CMC Architectural Split (NLSQ optimizer choice)
- ADR-004: Consensus Monte Carlo for Bayesian Inference (CMC multiprocessing architecture)