Performance Tuning¶
The heterodyne package runs entirely on CPU and uses JAX’s XLA compiler for JIT compilation. This page covers the key levers for optimising throughput and memory usage.
XLA Configuration¶
The heterodyne-config-xla command-line tool configures XLA compiler
flags for optimal CPU performance on the current machine:
# Auto-detect and print recommended XLA flags
heterodyne-config-xla
# Apply flags to the current shell session
eval $(heterodyne-config-xla --export)
The tool detects:
CPU instruction set extensions (AVX2, AVX-512).
Number of physical cores.
NUMA topology (if applicable).
It sets XLA_FLAGS to enable the appropriate LLVM target features
and intra-op parallelism.
NUMA Awareness¶
On multi-socket systems (common in workstations and compute nodes),
NUMA-aware thread pinning prevents cross-socket memory access penalties.
The heterodyne.device.cpu module detects the NUMA topology and
configures JAX accordingly.
For manual control:
# Pin to NUMA node 0
numactl --cpunodebind=0 --membind=0 python analysis.py
This is especially important for CMC analyses where multiple NUTS chains run in parallel – each chain should ideally be bound to a single NUMA node.
Thread Count Optimisation¶
JAX and the underlying BLAS libraries use OMP_NUM_THREADS to
control parallelism. The optimal value depends on the workload:
- NLSQ fitting
Typically benefits from all available cores. Set
OMP_NUM_THREADSequal to the number of physical cores (not hyperthreads):export OMP_NUM_THREADS=16 # For a 16-core machine
- CMC sampling
When running multiple NUTS chains in parallel, each chain should use fewer threads to avoid oversubscription:
# 4 chains on a 16-core machine: 4 threads per chain export OMP_NUM_THREADS=4
- General rule
Total threads across all processes should not exceed the number of physical cores. Hyperthreading typically does not help for numerically intensive JAX workloads.
Memory Management¶
Large \(C_2\) matrices can consume significant memory. Key strategies:
Frame trimming¶
If only a subset of the measurement time is of interest, trim the frame range before fitting:
data = loader.load(frame_start=100, frame_end=500)
This reduces the \(C_2\) matrix from \(N^2\) to \((N_\text{trim})^2\) elements.
Strategy selection¶
The NLSQ strategy affects peak memory:
JIT – Highest memory (full Jacobian in XLA). Best for \(N < 500\) frames.
Chunked – Moderate memory. Good for \(500 < N < 2000\).
Sequential – Lowest memory. Use for \(N > 2000\) or memory-constrained environments.
config = NLSQConfig(strategy="chunked", chunk_size=256)
CMC Backend Selection¶
The CMC runner supports two execution modes:
- Sequential (default)
Shards are processed one at a time in the main process. Simplest and most debuggable. Use when memory is the bottleneck.
- Multiprocessing
Shards are distributed across worker processes using Python’s
multiprocessingmodule. Each worker runs its NUTS chains independently. Use when wall time is the bottleneck and the machine has enough memory for concurrent shards.
cmc_config = CMCConfig(
backend="multiprocessing",
max_workers=4, # Number of parallel shard workers
)
Profiling Tips¶
- JIT compilation time
The first call to a JIT-compiled function incurs compilation overhead. Subsequent calls with the same input shapes are fast. If compilation is slow, check that input shapes are consistent (avoid triggering recompilation with different array sizes).
- Memory profiling
Use
jax.live_arrays()or system tools (htop,psutil) to monitor JAX array allocations during fitting.- Timing
The
wall_time_secondsfield in bothNLSQResultandCMCResultreports end-to-end fitting time, excluding data loading and I/O.