Source code for heterodyne.cli.xla_config
"""XLA configuration for JAX on CPU."""
from __future__ import annotations
import os
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
[docs]
def configure_xla(
num_threads: int | None = None,
disable_jit: bool = False,
enable_x64: bool = True,
) -> dict[str, str]:
"""Configure XLA/JAX environment variables for CPU execution.
MUST be called before importing JAX.
Args:
num_threads: Number of CPU threads (None for auto)
disable_jit: Disable JIT compilation (for debugging)
enable_x64: Enable 64-bit float precision
Returns:
Dict of environment variables that were set
"""
env_vars = {}
# Force CPU backend
os.environ["JAX_PLATFORM_NAME"] = "cpu"
env_vars["JAX_PLATFORM_NAME"] = "cpu"
# Thread configuration
if num_threads is not None:
existing = os.environ.get("XLA_FLAGS", "")
new_flags = (
"--xla_cpu_multi_thread_eigen=true"
f" --intra_op_parallelism_threads={num_threads}"
)
if new_flags not in existing:
os.environ["XLA_FLAGS"] = f"{existing} {new_flags}".strip()
os.environ["OMP_NUM_THREADS"] = str(num_threads)
os.environ["MKL_NUM_THREADS"] = str(num_threads)
env_vars["XLA_FLAGS"] = os.environ["XLA_FLAGS"]
env_vars["OMP_NUM_THREADS"] = str(num_threads)
env_vars["MKL_NUM_THREADS"] = str(num_threads)
# Disable JIT for debugging
if disable_jit:
os.environ["JAX_DISABLE_JIT"] = "1"
env_vars["JAX_DISABLE_JIT"] = "1"
# Enable 64-bit precision
if enable_x64:
os.environ["JAX_ENABLE_X64"] = "1"
env_vars["JAX_ENABLE_X64"] = "1"
return env_vars
[docs]
def get_cpu_info() -> dict[str, int | str]:
"""Get CPU information for configuration.
Returns:
Dict with cpu_count, physical_cores, etc.
"""
import psutil
info = {
"cpu_count": psutil.cpu_count(),
"physical_cores": psutil.cpu_count(logical=False) or psutil.cpu_count(),
}
# Available memory
mem = psutil.virtual_memory()
info["available_memory_gb"] = round(mem.available / (1024**3), 1)
info["total_memory_gb"] = round(mem.total / (1024**3), 1)
return info
[docs]
def auto_configure() -> dict[str, str]:
"""Automatically configure XLA based on system resources.
Returns:
Dict of environment variables set
"""
cpu_info = get_cpu_info()
# Use physical cores (not hyperthreaded)
num_threads = cpu_info.get("physical_cores", 4)
return configure_xla(num_threads=num_threads, enable_x64=True)
[docs]
def main() -> None:
"""CLI entry point for XLA configuration."""
import argparse
parser = argparse.ArgumentParser(
description="Configure XLA for heterodyne analysis"
)
parser.add_argument(
"--threads",
type=int,
default=None,
help="Number of CPU threads (default: auto)",
)
parser.add_argument(
"--no-x64",
action="store_true",
help="Disable 64-bit precision",
)
parser.add_argument(
"--debug",
action="store_true",
help="Disable JIT for debugging",
)
parser.add_argument(
"--info",
action="store_true",
help="Print CPU info and exit",
)
args = parser.parse_args()
if args.info:
info = get_cpu_info()
print("CPU Information:")
for key, value in info.items():
print(f" {key}: {value}")
return
env_vars = configure_xla(
num_threads=args.threads,
disable_jit=args.debug,
enable_x64=not args.no_x64,
)
print("XLA Configuration:")
for key, value in env_vars.items():
print(f" {key}={value}")
if __name__ == "__main__":
main()