"""CPU detection and HPC optimization utilities.
This module provides hardware-aware configuration for JAX workloads on CPU,
including physical core detection, NUMA topology awareness, and optimal
environment variable configuration for HPC clusters.
"""
from __future__ import annotations
import logging
import os
import platform
import re
import subprocess
import sys
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping
# Module-level flag initialised to False. It is not read or written by any code
# visible in this section; presumably an HPC configuration routine elsewhere in
# this module sets it to record one-time setup (TODO confirm).
_HPC_CONFIGURED: bool = False
# Module logger, named after this module per the standard logging convention.
logger: logging.Logger = logging.getLogger(__name__)
[docs]
@dataclass
class CPUInfo:
    """Snapshot of the host CPU's hardware characteristics.

    Only the two core counts are required. Every other field has a neutral
    default, so platform-specific detectors can fill in whatever they manage
    to discover and leave the rest untouched.

    Attributes:
        physical_cores: Count of physical cores (SMT siblings not counted).
        logical_cores: Count of logical CPUs, including hyperthread siblings.
        numa_nodes: Count of NUMA nodes (memory domains).
        architecture: Machine architecture, e.g. ``'x86_64'`` or ``'arm64'``.
        vendor: CPU vendor, e.g. ``'Intel'``, ``'AMD'`` or ``'Apple'``.
        model_name: Full model string of the CPU.
        has_avx: True when AVX support was detected.
        has_avx2: True when AVX2 support was detected.
        has_avx512: True when AVX-512 support was detected.
        cache_sizes: Cache level name (e.g. ``'L3'``) mapped to its size in bytes.
    """

    physical_cores: int
    logical_cores: int
    numa_nodes: int = field(default=1)
    architecture: str = field(default="")
    vendor: str = field(default="")
    model_name: str = field(default="")
    has_avx: bool = field(default=False)
    has_avx2: bool = field(default=False)
    has_avx512: bool = field(default=False)
    # Fresh dict per instance; a shared mutable default would leak between objects.
    cache_sizes: dict[str, int] = field(default_factory=dict)
[docs]
def detect_cpu_info() -> CPUInfo:
    """Detect CPU hardware information for the current host.

    Returns:
        CPUInfo dataclass with hardware details.

    Note:
        Detection is platform-specific and best-effort:

        - Linux: ``lscpu``, ``/proc/cpuinfo`` and, when needed to tell
          physical from logical cores, sysfs CPU topology.
        - macOS: ``sysctl``.
        - Windows and any other platform: only ``os.cpu_count()`` and
          ``platform.machine()`` are consulted, so ``physical_cores`` equals
          ``logical_cores`` and all ISA flags stay ``False``.
    """
    system = platform.system()
    if system == "Linux":
        return _detect_linux_cpu()
    if system == "Darwin":
        return _detect_macos_cpu()
    # Windows and everything else: no hardware probing beyond os.cpu_count().
    return _detect_fallback_cpu()
def _detect_linux_cpu() -> CPUInfo:
    """Detect CPU info on Linux using lscpu, sysfs topology and /proc/cpuinfo.

    Both core counts start at ``os.cpu_count()`` (1 if unknown). ``lscpu``
    output, parsed by ``_parse_lscpu``, refines them. If the physical and
    logical counts are still equal afterwards, sysfs topology is consulted to
    disambiguate. Finally, AVX flags, model name and vendor are read from
    ``/proc/cpuinfo``. Every source is best-effort: a missing tool or file
    leaves the corresponding fields unchanged.

    Returns:
        CPUInfo populated with whatever could be detected.

    Warns:
        RuntimeWarning: If more than one logical CPU exists, the counts are
        still equal after lscpu, and sysfs yielded no topology.
    """
    logical_cores = os.cpu_count() or 1
    info = CPUInfo(
        physical_cores=logical_cores,
        logical_cores=logical_cores,
        architecture=platform.machine(),
    )

    lscpu_output = _read_lscpu_output()
    if lscpu_output is not None:
        info = _parse_lscpu(lscpu_output, info)

    # Equal counts mean either SMT is off or nothing could tell the two apart
    # (including lscpu being absent). Sysfs topology does not depend on lscpu,
    # so consult it in every such case rather than only after lscpu succeeded.
    if info.physical_cores == info.logical_cores:
        sysfs_physical_cores = _detect_linux_physical_cores_from_sysfs()
        if sysfs_physical_cores is not None:
            info.physical_cores = sysfs_physical_cores
        elif info.logical_cores > 1:
            warnings.warn(
                "Could not disambiguate Linux physical CPU core count from lscpu "
                "or sysfs; using logical CPU count for physical_cores.",
                RuntimeWarning,
                stacklevel=2,
            )

    try:
        with open("/proc/cpuinfo", encoding="utf-8", errors="replace") as f:
            cpuinfo_text = f.read()
    except OSError:
        pass  # best-effort: keep whatever lscpu provided
    else:
        _apply_proc_cpuinfo(cpuinfo_text, info)
    return info


def _read_lscpu_output() -> str | None:
    """Return ``lscpu`` stdout, or None if it cannot be run or exits non-zero."""
    # lscpu translates its field labels under non-English locales, while
    # _parse_lscpu matches the English labels; LC_ALL=C pins the English ones.
    env = {**os.environ, "LC_ALL": "C"}
    try:
        result = subprocess.run(
            ["lscpu"],
            capture_output=True,
            text=True,
            timeout=5,
            check=False,
            env=env,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return None
    return result.stdout if result.returncode == 0 else None


def _apply_proc_cpuinfo(cpuinfo: str, info: CPUInfo) -> None:
    """Fill AVX flags, model name and vendor on *info* from /proc/cpuinfo text.

    ``flags`` lines set the three AVX fields. ``model name`` and ``vendor_id``
    are only applied while the corresponding field is still empty, so values
    already supplied by lscpu win. Scanning stops once a flags line, a model
    name and a vendor have all been seen.
    """
    avx_detected = False
    for line in cpuinfo.split("\n"):
        if line.startswith("flags") and ":" in line:
            flags_set = set(line.split(":", 1)[1].split())
            info.has_avx = "avx" in flags_set
            info.has_avx2 = "avx2" in flags_set
            # Any avx512* extension (avx512f, avx512bw, ...) counts.
            info.has_avx512 = any(flag.startswith("avx512") for flag in flags_set)
            avx_detected = True
        elif line.startswith("model name") and not info.model_name:
            info.model_name = line.split(":", 1)[1].strip()
        elif line.startswith("vendor_id") and not info.vendor:
            vendor_str = line.split(":", 1)[1].strip()
            if "Intel" in vendor_str:
                info.vendor = "Intel"
            elif "AMD" in vendor_str:
                info.vendor = "AMD"
            else:
                info.vendor = vendor_str
        if avx_detected and info.model_name and info.vendor:
            break
def _detect_linux_physical_cores_from_sysfs() -> int | None:
"""Detect physical cores by deduplicating Linux topology package/core IDs."""
topology_root = Path("/sys/devices/system/cpu")
physical_cores: set[tuple[int, int]] = set()
for cpu_dir in topology_root.glob("cpu[0-9]*"):
topology_dir = cpu_dir / "topology"
try:
package_id = int((topology_dir / "physical_package_id").read_text().strip())
core_id = int((topology_dir / "core_id").read_text().strip())
except (OSError, ValueError):
continue
physical_cores.add((package_id, core_id))
if not physical_cores:
return None
return len(physical_cores)
def _safe_int(value: str) -> int | None:
"""Parse an integer from lscpu value, handling formats like '4 (2 online)'."""
try:
return int(value.split()[0])
except (ValueError, IndexError):
return None
def _parse_lscpu(output: str, info: CPUInfo) -> CPUInfo:
    """Parse ``lscpu`` output into *info* (mutated in place) and return it.

    Lines are ``label: value`` pairs. Labels are compared case-insensitively
    against their English names, so the output should come from lscpu running
    in an English/C locale. Recognised labels:

    - ``CPU(s)`` -> ``logical_cores``
    - ``Core(s) per socket`` x ``Socket(s)`` -> ``physical_cores``, when both
      are positive integers
    - ``Thread(s) per core`` -> fallback ``physical_cores`` of
      ``logical_cores // threads`` when the socket product is unavailable
      (e.g. ``Socket(s)`` is printed as ``-``)
    - ``NUMA node(s)`` -> ``numa_nodes``
    - ``Architecture`` -> ``architecture``
    - ``Vendor ID`` -> ``vendor``, normalised to ``Intel``/``AMD`` when matched
    - ``Model name`` -> ``model_name``; only the first non-``-`` value is
      used, and only if the field is still empty

    Numeric values that cannot be parsed are ignored, leaving fields unchanged.
    """
    cores_per_socket: int | None = None
    sockets: int | None = None
    threads_per_core: int | None = None
    for line in output.splitlines():
        key, sep, value = line.partition(":")
        if not sep:
            continue
        key = key.strip().lower()
        value = value.strip()
        if key == "cpu(s)":
            parsed = _safe_int(value)
            if parsed is not None:
                info.logical_cores = parsed
        elif key == "core(s) per socket":
            cores_per_socket = _safe_int(value)
        elif key == "socket(s)":
            sockets = _safe_int(value)
        elif key == "thread(s) per core":
            threads_per_core = _safe_int(value)
        elif key == "numa node(s)":
            parsed = _safe_int(value)
            if parsed is not None:
                info.numa_nodes = parsed
        elif key == "architecture":
            info.architecture = value
        elif key == "vendor id":
            if "Intel" in value:
                info.vendor = "Intel"
            elif "AMD" in value:
                info.vendor = "AMD"
            else:
                info.vendor = value
        elif key == "model name":
            # Heterogeneous systems list several models; keep the first one.
            if value and value != "-" and not info.model_name:
                info.model_name = value

    if (cores_per_socket or 0) > 0 and (sockets or 0) > 0:
        info.physical_cores = cores_per_socket * sockets
    elif (threads_per_core or 0) > 0:
        # Socket topology unavailable: derive cores from the SMT width instead.
        info.physical_cores = max(1, info.logical_cores // threads_per_core)
    return info
def _detect_macos_cpu() -> CPUInfo:
    """Detect CPU info on macOS by querying ``sysctl`` keys.

    Core counts default to ``os.cpu_count()`` (1 if unknown) and are replaced
    by ``hw.physicalcpu`` / ``hw.logicalcpu`` when sysctl answers. The brand
    string supplies the model name and, if it mentions Intel or Apple, the
    vendor. On ``arm64`` the vendor is forced to Apple with one NUMA node.
    The ``hw.optional.avx*`` keys set the AVX flags when they are present.
    """

    def read_sysctl(name: str) -> str | None:
        # Best-effort single query: None on launch failure, timeout or non-zero exit.
        try:
            completed = subprocess.run(
                ["sysctl", "-n", name],
                capture_output=True,
                text=True,
                timeout=5,
                check=False,
            )
        except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
            return None
        if completed.returncode != 0:
            return None
        return completed.stdout.strip()

    cpu_count = os.cpu_count() or 1
    info = CPUInfo(
        physical_cores=cpu_count,
        logical_cores=cpu_count,
        architecture=platform.machine(),
    )

    core_keys = (
        ("physical_cores", "hw.physicalcpu"),
        ("logical_cores", "hw.logicalcpu"),
    )
    for attr, key in core_keys:
        raw = read_sysctl(key)
        if raw:
            setattr(info, attr, int(raw))

    brand = read_sysctl("machdep.cpu.brand_string")
    if brand:
        info.model_name = brand
        vendor = next((name for name in ("Intel", "Apple") if name in brand), None)
        if vendor is not None:
            info.vendor = vendor

    if info.architecture == "arm64":
        # Apple Silicon: unified memory, so report a single NUMA node.
        info.vendor = "Apple"
        info.numa_nodes = 1

    # These keys only exist on Intel Macs; missing keys leave the flags False.
    avx_keys = (
        ("has_avx", "hw.optional.avx1_0"),
        ("has_avx2", "hw.optional.avx2_0"),
        ("has_avx512", "hw.optional.avx512f"),
    )
    for attr, key in avx_keys:
        raw = read_sysctl(key)
        if raw:
            setattr(info, attr, raw == "1")
    return info
def _detect_fallback_cpu() -> CPUInfo:
    """Build a minimal CPUInfo when no platform-specific probe exists.

    Both core counts come from ``os.cpu_count()`` (1 if unknown), so they are
    identical. Apart from the architecture, all other fields keep their
    dataclass defaults.
    """
    cores = os.cpu_count() or 1
    return CPUInfo(
        architecture=platform.machine(),
        logical_cores=cores,
        physical_cores=cores,
    )
[docs]
def get_optimal_batch_size(
    cpu_info: CPUInfo | None = None,
    data_size: int = 1000,
    element_bytes: int = 8,
) -> int:
    """Calculate optimal batch size based on CPU cache hierarchy.

    The heuristic sizes a batch so that ``batch * data_size * element_bytes``
    fills about half of the L3 cache. It then rounds the batch down to a power
    of two for SIMD-friendly shapes and clamps it to ``[16, 4096]``.

    Args:
        cpu_info: CPU information (auto-detected if None). If its
            ``cache_sizes`` has no ``"L3"`` entry, 8 MiB per physical core
            is assumed.
        data_size: Size of the input data dimension; must be positive.
        element_bytes: Bytes per element (8 for float64, 4 for float32);
            must be positive.

    Returns:
        Recommended batch size: a power of two in ``[16, 4096]``.

    Raises:
        ValueError: If *data_size* or *element_bytes* is not positive.
    """
    # Validate first so bad arguments fail fast without probing hardware.
    if data_size <= 0:
        raise ValueError(f"Invalid data size: {data_size}")
    if element_bytes <= 0:
        raise ValueError(f"Invalid element size: {element_bytes}")
    if cpu_info is None:
        cpu_info = detect_cpu_info()

    # Default L3 assumption: 8 MiB per physical core, shared.
    default_l3 = 8 * 1024 * 1024 * cpu_info.physical_cores
    l3_cache = cpu_info.cache_sizes.get("L3", default_l3)
    # Target ~50% of L3 for the working set.
    target_bytes = l3_cache // 2
    batch_size = max(1, target_bytes // (data_size * element_bytes))
    # Round down to a power of two; batch_size >= 1, so bit_length() >= 1.
    batch_size = 1 << (batch_size.bit_length() - 1)
    # Both bounds are powers of two, so the result stays a power of two.
    return max(16, min(batch_size, 4096))
[docs]
def get_jax_cpu_flags(
    cpu_info: CPUInfo | None = None,
    num_devices: int | None = None,
) -> str:
    """Generate XLA_FLAGS for optimal JAX CPU execution.

    Args:
        cpu_info: CPU information. It is only consulted, and auto-detected if
            None, when *num_devices* is not given.
        num_devices: Number of CPU devices to expose (default: physical cores).

    Returns:
        XLA_FLAGS string to set in the environment, currently
        ``--xla_force_host_platform_device_count=<n>``.

    Raises:
        ValueError: If the resulting device count is less than 1.
    """
    if num_devices is None:
        # Detection can spawn subprocesses (lscpu/sysctl), so only run it
        # when the device count actually depends on the hardware.
        if cpu_info is None:
            cpu_info = detect_cpu_info()
        num_devices = cpu_info.physical_cores
    if num_devices < 1:
        raise ValueError(f"Invalid number of devices: {num_devices}")
    flags = [
        f"--xla_force_host_platform_device_count={num_devices}",
    ]
    return " ".join(flags)
def _warn_if_jax_already_imported(*, strict: bool) -> None:
    """Flag env-based JAX configuration that happens after JAX is loaded.

    Once ``jax`` is in ``sys.modules``, later XLA_FLAGS / JAX_PLATFORMS
    changes may be ignored. In that case this helper raises or warns;
    otherwise it does nothing.

    Args:
        strict: Raise ``RuntimeError`` instead of emitting a ``RuntimeWarning``.
    """
    if "jax" in sys.modules:
        text = (
            "JAX has already been imported; XLA_FLAGS and JAX_PLATFORMS changes "
            "made by configure_jax_cpu() may be ignored. Call device configuration "
            "before importing JAX."
        )
        if not strict:
            # stacklevel=3 skips this helper and its direct caller.
            warnings.warn(text, RuntimeWarning, stacklevel=3)
            return
        raise RuntimeError(text)