"""Hardware configuration and CMC backend selection.
This module provides hardware detection utilities for selecting the optimal
CMC (Consensus Monte Carlo) backend based on the available resources.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from enum import Enum
from typing import Literal
from heterodyne.device.cpu import CPUInfo, detect_cpu_info
[docs]
class ClusterType(Enum):
"""HPC cluster job scheduler type."""
STANDALONE = "standalone"
PBS = "pbs"
SLURM = "slurm"
[docs]
class CMCBackend(Enum):
"""CMC execution backend."""
JIT = "jit" # JAX jit for single-node parallelism
MULTIPROCESSING = "multiprocessing" # Python multiprocessing
PBS = "pbs" # PBS array jobs
SLURM = "slurm" # Slurm array jobs
[docs]
@dataclass
class HardwareConfig:
"""Hardware configuration for CMC execution.
Attributes:
cpu_info: Detected CPU information.
cluster_type: Type of HPC cluster (if any).
available_cores: Number of CPU cores available for computation.
memory_gb: Available memory in GB.
recommended_backend: Recommended CMC backend based on hardware.
recommended_chains: Recommended number of MCMC chains.
max_parallel_chains: Maximum chains that can run in parallel.
"""
cpu_info: CPUInfo
cluster_type: ClusterType
available_cores: int
memory_gb: float
recommended_backend: CMCBackend
recommended_chains: int
max_parallel_chains: int
[docs]
def detect_cluster_type() -> ClusterType:
"""Detect the HPC cluster scheduler type from environment.
Returns:
ClusterType enum indicating the detected scheduler.
"""
# Check for PBS
if any(
key in os.environ for key in ["PBS_JOBID", "PBS_NODEFILE", "PBS_ENVIRONMENT"]
):
return ClusterType.PBS
# Check for Slurm
if any(
key in os.environ
for key in ["SLURM_JOB_ID", "SLURM_NODELIST", "SLURM_CLUSTER_NAME"]
):
return ClusterType.SLURM
return ClusterType.STANDALONE
[docs]
def get_available_memory() -> float:
"""Get available system memory in GB.
Returns:
Available memory in GB, or a conservative estimate if detection fails.
"""
try:
import psutil
mem = psutil.virtual_memory()
return mem.available / (1024**3)
except (ImportError, OSError):
# Fallback: assume 8 GB available
# OSError covers psutil.AccessDenied in restricted containers
return 8.0
[docs]
def detect_hardware() -> HardwareConfig:
"""Detect hardware configuration and recommend CMC settings.
Returns:
HardwareConfig with detected settings and recommendations.
Note:
This function considers:
- CPU core count (physical cores preferred)
- Available memory (for chain parallelism)
- Cluster environment (PBS/Slurm for distributed)
- NUMA topology (for jit backend)
"""
cpu_info = detect_cpu_info()
cluster_type = detect_cluster_type()
memory_gb = get_available_memory()
# Determine available cores
# In cluster jobs, respect scheduler allocation
if cluster_type == ClusterType.PBS:
ncpus = os.environ.get("PBS_NCPUS") or os.environ.get("NCPUS")
if ncpus:
try:
available_cores = int(ncpus)
except ValueError:
available_cores = cpu_info.physical_cores
else:
available_cores = cpu_info.physical_cores
elif cluster_type == ClusterType.SLURM:
cpus_per_task = os.environ.get("SLURM_CPUS_PER_TASK")
if cpus_per_task:
try:
available_cores = int(cpus_per_task)
except ValueError:
available_cores = cpu_info.physical_cores
else:
available_cores = cpu_info.physical_cores
else:
available_cores = cpu_info.physical_cores
# Determine recommended backend and chain count
backend, chains, max_parallel = _recommend_backend(
cpu_info=cpu_info,
cluster_type=cluster_type,
available_cores=available_cores,
memory_gb=memory_gb,
)
return HardwareConfig(
cpu_info=cpu_info,
cluster_type=cluster_type,
available_cores=available_cores,
memory_gb=memory_gb,
recommended_backend=backend,
recommended_chains=chains,
max_parallel_chains=max_parallel,
)
def _recommend_backend(
cpu_info: CPUInfo,
cluster_type: ClusterType,
available_cores: int,
memory_gb: float,
) -> tuple[CMCBackend, int, int]:
"""Recommend CMC backend and chain configuration.
Args:
cpu_info: CPU hardware information.
cluster_type: Detected cluster type.
available_cores: Number of available cores.
memory_gb: Available memory in GB.
Returns:
Tuple of (backend, recommended_chains, max_parallel_chains).
"""
# Memory per chain estimate: ~2-4 GB for typical XPCS models
memory_per_chain = 3.0
max_chains_by_memory = max(1, int(memory_gb / memory_per_chain))
# For cluster environments, prefer array job backends
if cluster_type == ClusterType.PBS:
# PBS: use array jobs for embarrassingly parallel chains
chains = min(8, max_chains_by_memory)
return (CMCBackend.PBS, chains, chains)
if cluster_type == ClusterType.SLURM:
# Slurm: similar to PBS
chains = min(8, max_chains_by_memory)
return (CMCBackend.SLURM, chains, chains)
# Standalone mode: choose between jit and multiprocessing
if available_cores >= 4:
# jit is efficient for 4+ cores with NUMA awareness
max_parallel = min(available_cores, max_chains_by_memory, 8)
chains = max_parallel
return (CMCBackend.JIT, chains, max_parallel)
else:
# For small core counts, multiprocessing has less overhead
max_parallel = min(available_cores, max_chains_by_memory)
chains = max(2, max_parallel) # At least 2 chains for diagnostics
return (CMCBackend.MULTIPROCESSING, chains, max_parallel)
[docs]
def get_backend_name(
backend: CMCBackend,
) -> Literal["jit", "multiprocessing", "pbs", "slurm"]:
"""Get the string name of a CMC backend for configuration.
Args:
backend: CMCBackend enum value.
Returns:
Backend name string for use in configuration.
"""
return backend.value
[docs]
def get_device_status() -> dict[str, object]:
"""Get current device configuration status.
Returns:
Dictionary with current device settings and detected hardware.
"""
hw = detect_hardware()
return {
"cpu": {
"physical_cores": hw.cpu_info.physical_cores,
"logical_cores": hw.cpu_info.logical_cores,
"numa_nodes": hw.cpu_info.numa_nodes,
"architecture": hw.cpu_info.architecture,
"vendor": hw.cpu_info.vendor,
"model": hw.cpu_info.model_name,
},
"cluster": {
"type": hw.cluster_type.value,
"available_cores": hw.available_cores,
"memory_gb": round(hw.memory_gb, 1),
},
"cmc_recommendation": {
"backend": hw.recommended_backend.value,
"chains": hw.recommended_chains,
"max_parallel": hw.max_parallel_chains,
},
"environment": {
"XLA_FLAGS": os.environ.get("XLA_FLAGS", ""),
"JAX_PLATFORMS": os.environ.get("JAX_PLATFORMS", ""),
"OMP_NUM_THREADS": os.environ.get("OMP_NUM_THREADS", ""),
},
}