"""Multiprocessing backend for CMC sharded MCMC execution.
This module provides parallel NUTS execution using Python's multiprocessing
module for CPU-based parallelism across CMC shards. Each shard runs as a
separate spawned process with its own JAX initialization, avoiding JAX
shared-state issues across forked processes.
Key design decisions:
- ``mp_context="spawn"`` (not fork): JAX cannot be safely shared across
fork. Spawned workers re-initialize JAX from scratch.
- All NumPyro imports inside worker functions: spawn safety requires that
no JAX/NumPyro state exists at import time in the child process.
- Shared memory for common data: ``SharedDataManager`` places config,
parameter-space state, and per-shard arrays in shared memory once,
avoiding redundant serialization overhead through spawn.
- LPT scheduling: shards dispatched highest-cost-first to minimize
tail latency on identical parallel workers.
- Heartbeat thread inside each worker: emits liveness pings so the parent
can detect frozen processes and apply ``heartbeat_timeout``.
- Adaptive polling: poll interval grows when no shard has completed
recently, shrinking CPU overhead during long-running shards.
Optimizations carried over from homodyne v2.22.2:
- Batch PRNG key generation: pre-generate all shard keys in one JAX call.
- Per-shard shared memory (packed format): 4 segments total regardless
of shard count, avoiding fd exhaustion.
- deque for pending shards: O(1) popleft instead of O(n) list.pop(0).
- Persistent compilation cache via ``jax.config.update`` (env var alone
insufficient in JAX 0.8+, ``min_compile_time`` lowered to 0).
This backend is selected when ``config.num_chains >= 3``, or when
``config.backend_name == "multiprocessing"``.
"""
from __future__ import annotations
import logging
import multiprocessing as mp
import multiprocessing.shared_memory
import os
import queue
import threading
import time
from collections import deque
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
from tqdm import tqdm
from heterodyne.optimization.cmc.backends.base import (
_NUTS_EXTRA_FIELDS,
BackendCapabilities,
CMCBackend,
)
from heterodyne.utils.logging import get_logger, log_exception, with_context
if TYPE_CHECKING:
import jax.numpy as jnp
from heterodyne.optimization.cmc.config import CMCConfig
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
#: Number of heterodyne model parameters (14-parameter two-component model).
_N_PARAMS_HETERODYNE: int = 14
#: Memory constants for estimation.
_BYTES_PER_FLOAT64: int = 8
_CPU_MEMORY_OVERHEAD_FACTOR: float = 8.0 # Conservative: 14 params x gradient bufs
_BYTES_PER_GB: float = 1024.0**3
[docs]
@dataclass(frozen=True, slots=True)
class ArraySpec:
"""Schema for one packed shared-memory array key.
Attributes:
expected_dtype: Documented dtype. Advisory only — the packer uses
the actual array's dtype. Logged in shape-mismatch errors.
allow_none: Whether a caller may pass ``None`` for this key, in
which case it is stored as a zero-length sentinel.
description: Human-readable label, surfaced in error messages so
shape/size mismatches identify the offending physical array
(e.g. "two-time correlation matrix", not just "c2_data").
"""
expected_dtype: str
allow_none: bool
description: str
#: Self-describing schema for per-shard arrays packed into shared memory.
#: The per-shard ref dict carries ``shape: tuple[int, ...]`` per key so the
#: worker reshapes on load (eliminates the het_7221ba99 / het_457cc550 class
#: of silent shape-loss bugs by construction; codex C2 + gemini G3).
#:
#: Element-wise (random) shards use ``t1``/``t2``/``time_grid``; contiguous
#: shards use ``t``. Both sets must be listed here so the parent→worker
#: shared-memory pipeline forwards them; missing keys are silently dropped
#: (None in worker) — dispatch in the worker is keyed on which keys are
#: present in the loaded ``shard_data`` dict, not on this schema.
#:
#: SCHEMA INVARIANT (must hold across refactors):
#:
#: * ``ArraySpec`` instances live in **code**. They are re-imported by
#: every spawned worker via this module load, so parent and worker
#: always agree on the schema even though they don't share memory.
#:
#: * Per-shard **runtime metadata** (``shape``, ``dtype``, ``offset``,
#: ``size``, ``shm_name``) lives in the **ref dict** that travels
#: parent→worker via the spawn IPC wire format.
#:
#: * Do NOT serialise ``ArraySpec`` (or anything else from
#: ``_SHARD_ARRAY_SPECS``) into the ref dict. That re-introduces a
#: wire-format/schema-version coupling that the code-resident schema
#: specifically eliminates. If you need a new per-shard field, add it
#: to the ref dict and the unpacker — not to ``ArraySpec``.
_SHARD_ARRAY_SPECS: dict[str, ArraySpec] = {
"c2_data": ArraySpec(
expected_dtype="float64",
allow_none=False,
description="two-time correlation matrix",
),
"sigma": ArraySpec(
expected_dtype="float64",
allow_none=True,
description="per-pair measurement uncertainty",
),
"t": ArraySpec(
expected_dtype="float64",
allow_none=True,
description="contiguous-shard time axis",
),
"t1": ArraySpec(
expected_dtype="float64",
allow_none=True,
description="element-wise shard t1 coordinates",
),
"t2": ArraySpec(
expected_dtype="float64",
allow_none=True,
description="element-wise shard t2 coordinates",
),
"time_grid": ArraySpec(
expected_dtype="float64",
allow_none=True,
description="element-wise shard cumsum grid",
),
"weights": ArraySpec(
expected_dtype="float64",
allow_none=True,
description="per-pair weights",
),
}
#: Backwards-compat alias: callers that iterate keys still work.
_SHARD_ARRAY_KEYS: tuple[str, ...] = tuple(_SHARD_ARRAY_SPECS.keys())
# ---------------------------------------------------------------------------
# SharedDataManager
# ---------------------------------------------------------------------------
[docs]
class SharedDataManager:
"""Manages shared memory blocks for data common to all CMC shards.
Uses ``multiprocessing.shared_memory`` to share config dicts, parameter-
space state, initial values, and per-shard arrays across spawned worker
processes, avoiding redundant serialisation per shard.
Serialization note: uses ``pickle`` internally for trusted internal dicts
only (``CMCConfig.to_dict()``, parameter-space dict). This matches the
existing multiprocessing behaviour which also serialises all process
arguments. External/untrusted data is never serialised here.
Must be used as a context manager or ``cleanup()`` called in a
``finally`` block to avoid leaked shared memory segments on Linux.
Attributes:
_shared_blocks: All allocated ``SharedMemory`` segments.
_refs: Named references returned to callers.
"""
[docs]
def __init__(self) -> None:
self._shared_blocks: list[mp.shared_memory.SharedMemory] = []
self._refs: dict[str, dict[str, Any]] = {}
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
[docs]
def create_shared_bytes(self, name: str, data: bytes) -> dict[str, Any]:
"""Store raw bytes in a shared memory segment.
Args:
name: Logical name for this block (used for bookkeeping only).
data: Bytes to copy into shared memory.
Returns:
Reference dict with ``shm_name``, ``size``, and ``type`` keys.
"""
shm = mp.shared_memory.SharedMemory(create=True, size=max(1, len(data)))
shm.buf[: len(data)] = data
self._shared_blocks.append(shm)
ref: dict[str, Any] = {
"shm_name": shm.name,
"size": len(data),
"type": "bytes",
}
self._refs[name] = ref
return ref
[docs]
def create_shared_array(self, name: str, array: np.ndarray) -> dict[str, Any]:
"""Store a numpy array in a shared memory segment.
Args:
name: Logical name for this block.
array: Array to copy into shared memory (contiguous float64).
Returns:
Reference dict with ``shm_name``, ``shape``, ``dtype``, and
``type`` keys.
"""
arr = np.ascontiguousarray(array)
shm = mp.shared_memory.SharedMemory(create=True, size=max(1, arr.nbytes))
shared_arr = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
shared_arr[:] = arr
self._shared_blocks.append(shm)
ref: dict[str, Any] = {
"shm_name": shm.name,
"shape": arr.shape,
"dtype": str(arr.dtype),
"type": "array",
}
self._refs[name] = ref
return ref
[docs]
def create_shared_dict(self, name: str, d: dict[str, Any]) -> dict[str, Any]:
"""Serialise a trusted internal dict into shared memory.
Only used for ``CMCConfig.to_dict()`` and parameter-space dicts.
External/untrusted data is never passed here.
Args:
name: Logical name for this block.
d: Dict to serialise into shared memory.
Returns:
Reference dict (same as :meth:`create_shared_bytes`).
"""
import pickle as _pkl # noqa: S403 — trusted internal data only # nosec B403
return self.create_shared_bytes(name, _pkl.dumps(d))
[docs]
def create_shared_shard_arrays(
self,
shard_data_list: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Place per-shard numpy arrays into packed shared memory.
Instead of creating one ``SharedMemory`` segment per array per shard
(``n_shards * 4`` = many file descriptors), this concatenates all
shard arrays for each key into a single shared memory block. Only
``len(_SHARD_ARRAY_KEYS)`` segments are created regardless of shard
count.
Args:
shard_data_list: List of shard data dicts, each containing numpy
arrays keyed by ``_SHARD_ARRAY_KEYS`` plus a scalar
``noise_scale``. Arrays for ``sigma`` and ``weights`` may be
``None``, in which case a zero-length sentinel is stored.
Returns:
List of lightweight shard references (shm names + offsets).
Each ref dict is small enough to serialise cheaply through spawn.
"""
n_shards = len(shard_data_list)
key_meta: dict[str, dict[str, Any]] = {}
for key, spec in _SHARD_ARRAY_SPECS.items():
arrays: list[np.ndarray] = []
sizes: list[int] = []
shapes: list[tuple[int, ...]] = []
dtypes: list[str] = []
for sd in shard_data_list:
raw = sd.get(key)
if raw is None:
if not spec.allow_none:
raise ValueError(
f"Shard array {key!r} ({spec.description}) is "
"required but received None; ArraySpec.allow_none=False."
)
arr = np.empty(0, dtype=np.float64)
raw_shape: tuple[int, ...] = ()
else:
raw_np = np.asarray(raw)
raw_shape = tuple(raw_np.shape)
arr = np.ascontiguousarray(raw_np.ravel())
arrays.append(arr)
sizes.append(arr.shape[0])
shapes.append(raw_shape)
dtypes.append(str(arr.dtype))
# Use dtype from first non-empty array, or float64 fallback
reference_dtype = next(
(dtypes[i] for i in range(n_shards) if sizes[i] > 0),
"float64",
)
cast_arrays = [
a.astype(reference_dtype) if a.size > 0 else a for a in arrays
]
combined = (
np.concatenate(cast_arrays)
if any(a.size > 0 for a in cast_arrays)
else np.empty(0, dtype=reference_dtype)
)
shm = mp.shared_memory.SharedMemory(
create=True, size=max(1, combined.nbytes)
)
shared_arr = np.ndarray(
combined.shape, dtype=combined.dtype, buffer=shm.buf
)
if combined.size > 0:
shared_arr[:] = combined
self._shared_blocks.append(shm)
# Prefix-sum offsets for per-shard slicing
offsets: list[int] = [0]
for s in sizes[:-1]:
offsets.append(offsets[-1] + s)
key_meta[key] = {
"shm_name": shm.name,
"dtype": reference_dtype,
"offsets": offsets,
"sizes": sizes,
"shapes": shapes,
}
shard_refs: list[dict[str, Any]] = []
for i in range(n_shards):
ref: dict[str, Any] = {
"noise_scale": shard_data_list[i].get("noise_scale", 0.1),
}
for key in _SHARD_ARRAY_SPECS:
meta = key_meta[key]
ref[key] = {
"shm_name": meta["shm_name"],
"dtype": meta["dtype"],
"offset": meta["offsets"][i],
"size": meta["sizes"][i],
# Required (codex C2 + gemini G3): worker reshapes on load
# to preserve 2-D structure that ``ravel()`` flattened on pack.
"shape": meta["shapes"][i],
}
shard_refs.append(ref)
return shard_refs
[docs]
def cleanup(self) -> None:
"""Release all shared memory blocks.
Idempotent — safe to call more than once. Must be called in a
``finally`` block to avoid leaked segments.
"""
for shm in self._shared_blocks:
try:
shm.close()
shm.unlink()
except (FileNotFoundError, OSError):
pass
self._shared_blocks.clear()
self._refs.clear()
def __enter__(self) -> SharedDataManager:
return self
def __exit__(self, *exc: object) -> None:
self.cleanup()
# ---------------------------------------------------------------------------
# Shared-memory reconstruction helpers (called inside worker processes)
# ---------------------------------------------------------------------------
def _load_shared_bytes(ref: dict[str, Any]) -> bytes:
"""Reconstruct raw bytes from a shared memory reference."""
shm = mp.shared_memory.SharedMemory(name=ref["shm_name"], create=False)
try:
data = bytes(shm.buf[: ref["size"]])
finally:
shm.close()
return data
def _load_shared_dict(ref: dict[str, Any]) -> dict[str, Any]:
"""Reconstruct a trusted internal dict from a shared memory reference.
Only called inside worker processes for ``CMCConfig`` and parameter-space
dicts that were serialised by the parent process — never for external data.
"""
import pickle as _pkl # noqa: S403 — trusted internal data only # nosec B403
return _pkl.loads(_load_shared_bytes(ref)) # noqa: S301 # nosec B301
def _load_shared_array(ref: dict[str, Any]) -> np.ndarray:
"""Reconstruct a numpy array from a shared memory reference (copying)."""
shm = mp.shared_memory.SharedMemory(name=ref["shm_name"], create=False)
try:
arr = np.ndarray(
ref["shape"], dtype=np.dtype(ref["dtype"]), buffer=shm.buf
).copy()
finally:
shm.close()
return arr
def _load_shared_shard_data(shard_ref: dict[str, Any]) -> dict[str, Any]:
"""Reconstruct per-shard arrays from packed shared memory.
Each array key maps to a single concatenated ``SharedMemory`` block
shared across all shards. The per-shard ref carries ``offset``
(element index) and ``size`` (element count) to slice this shard's
portion. Sentinel entries with ``size == 0`` are returned as ``None``.
Args:
shard_ref: Lightweight shard reference created by
:meth:`SharedDataManager.create_shared_shard_arrays`.
Returns:
Shard data dict with numpy arrays (copied from shared memory)
and scalar ``noise_scale``. Arrays that were originally ``None``
are returned as ``None``.
"""
shard_data: dict[str, Any] = {"noise_scale": shard_ref["noise_scale"]}
for key, spec in _SHARD_ARRAY_SPECS.items():
arr_ref = shard_ref[key]
size = arr_ref["size"]
if size == 0:
shard_data[key] = None
continue
# Required schema field (codex C2 + gemini G3): the packer always
# stores ``shape`` so the worker can reverse the ``ravel()`` that
# the packer applies. Missing key indicates a stale ref dict from
# an old code path or a foreign serialiser — fail loud, not silent.
if "shape" not in arr_ref:
raise KeyError(
f"Shard ref for {key!r} ({spec.description}) is missing the "
"required 'shape' key. Ref dict must be produced by "
"SharedDataManager.create_shared_shard_arrays — schema is "
"not backwards-compatible with pre-Tier-2 packed refs."
)
shape: tuple[int, ...] = tuple(arr_ref["shape"])
expected_elements = int(np.prod(shape)) if shape else 0
# Empty shape () represents a scalar; otherwise prod must match size.
if shape and expected_elements != size:
raise ValueError(
f"Shard array {key!r} ({spec.description}, "
f"expected_dtype={spec.expected_dtype}): declared shape "
f"{shape} (prod={expected_elements}) does not match packed "
f"size {size}."
)
shm = mp.shared_memory.SharedMemory(name=arr_ref["shm_name"], create=False)
try:
dtype = np.dtype(arr_ref["dtype"])
offset = arr_ref["offset"]
total_elements = len(shm.buf) // dtype.itemsize
full_arr = np.ndarray((total_elements,), dtype=dtype, buffer=shm.buf)
arr = full_arr[offset : offset + size].copy()
finally:
shm.close()
if shape:
arr = arr.reshape(shape)
shard_data[key] = arr
return shard_data
# ---------------------------------------------------------------------------
# PRNG key helpers
# ---------------------------------------------------------------------------
def _generate_shard_keys(n_shards: int, seed: int = 42) -> list[tuple[int, ...]]:
"""Pre-generate all shard PRNG keys in a single JAX call.
Amortises JAX compilation overhead across all shards by generating
keys in the parent process before spawning workers.
Args:
n_shards: Number of shards to generate keys for.
seed: Base seed for PRNG key generation.
Returns:
List of raw ``uint32`` tuples that can be passed through spawn
and reconstructed via ``jax.random.wrap_key_data`` in workers.
"""
import jax
import jax.numpy as jnp
base_key = jax.random.PRNGKey(seed)
all_keys = jax.random.split(base_key, n_shards + 1)
shard_keys = all_keys[1:]
key_tuples: list[tuple[int, ...]] = []
for key in shard_keys:
raw = jax.random.key_data(key).flatten().astype(jnp.uint32)
key_tuples.append(tuple(int(x) for x in raw))
return key_tuples
# ---------------------------------------------------------------------------
# LPT scheduling
# ---------------------------------------------------------------------------
def _compute_lpt_schedule(
shard_data_list: list[dict[str, Any]],
) -> deque[int]:
"""Order shard indices by descending estimated cost (LPT heuristic).
Cost = ``n_points * (1 + normalised_noise)``, where noise is linearly
scaled to ``[0, 1]`` across shards. Dispatching the most expensive
shards first minimises tail latency on identical parallel workers.
Args:
shard_data_list: Shard dicts with ``"c2_data"`` (array or ``None``)
and ``"noise_scale"`` (float).
Returns:
Shard indices sorted by descending cost as a ``deque`` for O(1)
``popleft``.
"""
n_shards = len(shard_data_list)
sizes: list[int] = []
for i in range(n_shards):
c2 = shard_data_list[i].get("c2_data")
sizes.append(len(c2) if c2 is not None else 1)
noises = [
float(shard_data_list[i].get("noise_scale", 0.1)) for i in range(n_shards)
]
max_noise = max(noises) if noises else 1.0
min_noise = min(noises) if noises else 0.0
noise_range = max_noise - min_noise
if noise_range > 0.0:
costs = [
sizes[i] * (1.0 + (noises[i] - min_noise) / noise_range)
for i in range(n_shards)
]
else:
costs = [float(s) for s in sizes]
return deque(sorted(range(n_shards), key=lambda i: costs[i], reverse=True))
# ---------------------------------------------------------------------------
# Worker-process helpers
# ---------------------------------------------------------------------------
def _get_physical_cores() -> int:
"""Return physical core count, falling back to ``os.cpu_count() // 2``."""
try:
import psutil
physical = psutil.cpu_count(logical=False)
if physical is not None:
return physical
except ImportError:
pass
return max(1, (os.cpu_count() or 1) // 2)
def _compute_threads_per_worker(total_threads: int, workers: int) -> int:
"""Derive a conservative per-worker thread budget to avoid oversubscription.
Uses physical cores (not logical) as the safe pool.
Args:
total_threads: Total logical thread count available.
workers: Number of concurrent worker processes.
Returns:
Number of threads to allocate per worker (minimum 1).
"""
physical_cores = _get_physical_cores()
safe_pool = max(1, min(total_threads, physical_cores))
worker_count = max(1, workers)
return max(1, safe_pool // worker_count)
def _estimate_shard_time(
n_data: int,
n_params: int,
n_samples: int,
) -> float:
"""Rough estimate of per-shard wall-clock time in seconds.
Uses a simple linear model calibrated on 14-parameter NUTS runs:
approximately 0.5 ms per data point per sample, times a log factor
for parameter count.
Args:
n_data: Number of data points in the shard.
n_params: Number of varying model parameters.
n_samples: Total MCMC draws (warmup + samples).
Returns:
Estimated duration in seconds (lower bound; actual cost varies).
"""
import math
base_s_per_point_per_sample = 5e-4
param_factor = math.log1p(n_params) / math.log1p(_N_PARAMS_HETERODYNE)
return base_s_per_point_per_sample * n_data * n_samples * param_factor
def _validate_worker_result(result: dict[str, Any]) -> None:
"""Validate that a worker result dict is internally consistent.
Checks for non-finite sample values and shape consistency across
parameters. Raises ``ValueError`` on failure so the caller can mark
the shard as failed rather than propagating corrupt data into the
consensus step.
Args:
result: Worker result dict (as returned by ``_run_shard_worker``).
Raises:
ValueError: If samples contain NaN/Inf or shapes are inconsistent.
"""
samples: dict[str, np.ndarray] = result.get("samples", {})
if not samples:
raise ValueError("Worker result contains no samples dict")
expected_shape: tuple[int, ...] | None = None
for name, arr in samples.items():
if not isinstance(arr, np.ndarray):
raise ValueError(
f"Sample array for '{name}' is not a numpy array "
f"(got {type(arr).__name__})"
)
nan_count = int(np.sum(~np.isfinite(arr)))
if nan_count > 0:
raise ValueError(
f"Sample array for '{name}' contains {nan_count} non-finite values"
)
if expected_shape is None:
expected_shape = arr.shape
elif arr.shape != expected_shape:
raise ValueError(
f"Shape mismatch: '{name}' has shape {arr.shape}, "
f"expected {expected_shape}"
)
def _init_worker_jax(threads_per_worker: int, num_chains: int) -> None:
"""Per-worker JAX initialisation called before any JAX/NumPyro imports.
Configures XLA/OpenMP environment variables and calls
``jax.config.update`` to enable float64 and the persistent
compilation cache.
Args:
threads_per_worker: Number of OpenMP/MKL threads to allocate.
num_chains: Number of MCMC chains (sets XLA virtual device count).
"""
import re as _re
# Thread pinning — avoid oversubscription across concurrent workers.
# CRITICAL: clear OMP_PROC_BIND / OMP_PLACES to prevent all workers
# competing for the same physical cores (massive contention on NUMA).
os.environ["OMP_NUM_THREADS"] = str(threads_per_worker)
os.environ["MKL_NUM_THREADS"] = str(threads_per_worker)
os.environ["OPENBLAS_NUM_THREADS"] = str(threads_per_worker)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(threads_per_worker)
os.environ.pop("OMP_PROC_BIND", None)
os.environ.pop("OMP_PLACES", None)
# Enable float64 BEFORE importing JAX. Spawned workers start fresh
# processes and do not inherit the parent's jax.config state.
os.environ["JAX_ENABLE_X64"] = "true"
# Persistent compilation cache so later workers reuse compiled XLA
# programs from the first worker (JAX 0.8+: env var alone insufficient).
cache_dir = os.environ.get(
"JAX_COMPILATION_CACHE_DIR",
str(Path(os.path.expanduser("~/.cache/heterodyne/jax_cache"))),
)
os.environ["JAX_COMPILATION_CACHE_DIR"] = cache_dir
# Set XLA virtual device count to num_chains so parallel chain_method
# works correctly with multiple virtual CPU devices.
_xla_flags = os.environ.get("XLA_FLAGS", "")
_xla_flags = _re.sub(r"--xla_force_host_platform_device_count=\d+", "", _xla_flags)
os.environ["XLA_FLAGS"] = (
_xla_flags.strip() + f" --xla_force_host_platform_device_count={num_chains}"
)
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_compilation_cache_dir", cache_dir)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# ---------------------------------------------------------------------------
# Joint multi-phi pooled-model parallel shards (homodyne parity)
# ---------------------------------------------------------------------------
# These run the POOLED joint model (one NUTS per shard over flat (n_total,)
# data) in parallel worker processes, distinct from the single-angle shard
# worker below. The worker defers the heterodyne.core import so JAX loads only
# after _init_worker_jax has configured float64 / cache in the child.
def _run_joint_pooled_shard(payload: dict[str, Any]) -> Any:
"""Worker entry: run one pooled-model shard and return its CMCResult.
Imported lazily so the spawned child configures JAX (via the pool
initializer ``_init_worker_jax``) before the heterodyne core — which
imports JAX at module top — is loaded.
"""
from heterodyne.optimization.cmc.core import _joint_pooled_nuts_run
return _joint_pooled_nuts_run(**payload)
def _run_joint_pooled_shard_indexed(
item: tuple[int, dict[str, Any]],
) -> tuple[int, Any]:
"""Index-preserving wrapper so ``imap_unordered`` results can be reordered.
``imap_unordered`` yields results in completion order (enabling live
progress); tagging each result with its original shard index lets the
parent restore input order. Must stay module-level so the spawn pool can
serialise it for dispatch to child processes.
"""
idx, payload = item
return idx, _run_joint_pooled_shard(payload)
def _joint_payload_param_names(payload: dict[str, Any]) -> list[str]:
"""Reconstruct a shard's parameter-name vector from its payload.
Mirrors the ``parameter_names`` ordering built inside
``_joint_pooled_nuts_run`` (physics names, then ``contrast_{i}`` /
``offset_{i}`` per angle) so timed-out shards can be backfilled with a
correctly-shaped failed placeholder when no shard completed.
"""
space = payload["space"]
n_phi = int(payload["n_phi"])
physics_names = [n for n in space.varying_names if n not in ("contrast", "offset")]
contrast_names = [f"contrast_{i}" for i in range(n_phi)]
offset_names = [f"offset_{i}" for i in range(n_phi)]
return physics_names + contrast_names + offset_names
[docs]
def run_joint_pooled_shards_parallel(
payloads: list[dict[str, Any]],
*,
n_workers: int,
num_chains: int,
progress_bar: bool = True,
per_shard_timeout: int = 7200,
) -> list[Any]:
"""Run pooled-model shard payloads across a spawn process pool.
Reuses the proven ``_init_worker_jax`` initializer (float64, compilation
cache, OpenMP thread pinning). Each element of ``payloads`` is the kwargs
dict for :func:`heterodyne.optimization.cmc.core._joint_pooled_nuts_run`.
Returns the per-shard ``CMCResult`` objects in input order.
Progress is reported two ways so long runs never block silently: a
``tqdm`` bar (interactive terminals) AND periodic ``logger.info``
heartbeats (visible in log files, where tqdm is not). The heartbeat fires
every ``heartbeat`` seconds while waiting and reports elapsed time,
completed/total shards, and time since the last completion.
``per_shard_timeout`` bounds the wait: if **no** shard completes within
``per_shard_timeout`` seconds (every in-flight worker has exceeded the
budget), the pool is terminated and the remaining shards are returned as
failed placeholders. This prevents the het_491ee368 failure mode where a
divergence-storming shard ran 21 000 s (3x the 7200 s budget) unkilled.
Terminated shards become failed sub-posteriors that Consensus MC drops via
its ``convergence_passed`` gate — the run degrades, it does not hang.
"""
total_threads = os.cpu_count() or 1
threads_per_worker = _compute_threads_per_worker(total_threads, n_workers)
ctx = mp.get_context("spawn")
n_shards = len(payloads)
results: list[Any] = [None] * n_shards
indexed = list(enumerate(payloads))
total_divergences = 0
completed = 0
# Heartbeat cadence: frequent enough to reassure, capped at 5 min so the
# log is not flooded. Always < per_shard_timeout so the budget is checked.
heartbeat = max(30, min(per_shard_timeout // 4, 300))
start = time.monotonic()
last_completion = start
timed_out = False
with (
ctx.Pool(
processes=n_workers,
initializer=_init_worker_jax,
initargs=(threads_per_worker, num_chains),
) as pool,
tqdm(
total=n_shards,
desc=f"CMC joint shards ({n_workers} workers)",
unit="shard",
disable=not progress_bar,
) as pbar,
):
result_iter = pool.imap_unordered(_run_joint_pooled_shard_indexed, indexed)
while completed < n_shards:
try:
idx, result = result_iter.next(timeout=heartbeat)
except mp.TimeoutError:
now = time.monotonic()
since_last = now - last_completion
logger.info(
"[CMC joint] heartbeat: %d/%d shards complete, %.0fs elapsed, "
"%.0fs since last completion (%d workers, timeout=%ds)",
completed,
n_shards,
now - start,
since_last,
n_workers,
per_shard_timeout,
)
if per_shard_timeout > 0 and since_last > per_shard_timeout:
logger.error(
"[CMC joint] per_shard_timeout=%ds exceeded: no shard "
"completed in %.0fs; terminating %d in-flight workers. "
"%d/%d shards done; the rest are marked failed and dropped "
"from consensus. Likely cause: divergence storm from a "
"poor warm-start — run NLSQ first (optimizer: both) or "
"lower max_points_per_shard.",
per_shard_timeout,
since_last,
n_workers,
completed,
n_shards,
)
pool.terminate()
timed_out = True
break
continue
results[idx] = result
completed += 1
last_completion = time.monotonic()
total_divergences += int(getattr(result, "divergences", 0) or 0)
pbar.update(1)
pbar.set_postfix(
shard=idx,
div=total_divergences,
ok=bool(getattr(result, "convergence_passed", False)),
)
logger.info(
"[CMC joint] shard %d complete (%d/%d; divergences=%d, status=%s)",
idx,
completed,
n_shards,
int(getattr(result, "divergences", 0) or 0),
getattr(result, "convergence_status", "unknown"),
)
if timed_out:
# Backfill any shard the terminated pool never returned with a failed
# placeholder so Consensus MC downstream sees a full, correctly-shaped
# list (no ``None``); failed shards are dropped by the convergence gate.
from heterodyne.optimization.cmc.core import _create_failed_result
done_names = next((r.parameter_names for r in results if r is not None), None)
for i, r in enumerate(results):
if r is None:
names = done_names or _joint_payload_param_names(payloads[i])
results[i] = _create_failed_result(
list(names), "shard terminated: per_shard_timeout exceeded"
)
return results
# ---------------------------------------------------------------------------
# Core shard worker (runs entirely inside the spawned child process)
# ---------------------------------------------------------------------------
def _run_shard_worker(
shard_idx: int,
shard_data: dict[str, Any],
config_dict: dict[str, Any],
initial_values: dict[str, Any] | None,
result_queue: mp.Queue | None,
rng_key_tuple: tuple[int, ...] | None,
) -> dict[str, Any]:
"""Run NUTS on a single data shard inside a spawned worker process.
All imports of JAX and NumPyro occur here, after ``_init_worker_jax``
has already set environment variables.
Emits periodic heartbeat messages to ``result_queue`` via a background
thread so the parent can detect frozen workers and apply
``heartbeat_timeout``.
Args:
shard_idx: Shard index, used for logging and result identification.
shard_data: Dict with ``c2_data``, ``sigma``, ``t``, ``weights``,
``noise_scale``, ``q``, ``dt``, ``phi_angle``, ``contrast``,
``offset``, optionally ``reparam_config_dict``,
``parameter_space_dict``, and ``n_phi``.
config_dict: Serialised ``CMCConfig`` (from ``CMCConfig.to_dict()``).
initial_values: Optional NLSQ warm-start values (parameter name to
float).
result_queue: Multiprocessing queue for heartbeat and result
delivery. May be ``None`` in unit-test contexts.
rng_key_tuple: Pre-generated PRNG key as raw ``uint32`` tuple.
Falls back to ``PRNGKey(42 + shard_idx)`` when ``None``.
Returns:
Result dict with keys: ``type``, ``success``, ``shard_idx``,
``samples``, ``n_chains``, ``n_samples``, ``param_names``,
``extra_fields``, ``duration``, ``stats``. On failure: ``error``,
``error_category``, ``traceback``.
"""
# All JAX/NumPyro imports are deferred to here for spawn safety.
import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS
from numpyro.infer import initialization as numpyro_init
from heterodyne.config.parameter_names import ALL_PARAM_NAMES
from heterodyne.core.jax_backend import compute_c2_heterodyne
from heterodyne.core.physics_cmc import (
compute_c2_elementwise as _compute_c2_elementwise,
)
from heterodyne.core.physics_cmc import (
precompute_shard_grid as _precompute_shard_grid,
)
from heterodyne.optimization.cmc.config import CMCConfig
start_time = time.perf_counter()
worker_logger = get_logger(
__name__,
context={"run": config_dict.get("run_id"), "shard": shard_idx},
)
n_points = (
len(shard_data["c2_data"]) if shard_data.get("c2_data") is not None else 0
)
worker_logger.info(
"Shard %d starting: %d points",
shard_idx,
n_points,
)
# ------------------------------------------------------------------
# Heartbeat thread — emits liveness pings to the parent queue
# ------------------------------------------------------------------
stop_hb = threading.Event()
heartbeat_interval = 30.0
def _heartbeat_loop() -> None:
while True:
# Wait for stop signal or timeout
if stop_hb.wait(timeout=heartbeat_interval):
break
payload: dict[str, Any] = {
"type": "heartbeat",
"shard_idx": shard_idx,
"elapsed": time.perf_counter() - start_time,
}
if result_queue is not None:
try:
result_queue.put_nowait(payload)
except Exception: # noqa: BLE001 — best-effort heartbeat
pass
hb_thread = threading.Thread(target=_heartbeat_loop, daemon=True)
hb_thread.start()
try:
config = CMCConfig.from_dict(config_dict)
# Reconstruct PRNG key from raw uint32 tuple
if rng_key_tuple is not None:
rng_key = jax.random.wrap_key_data(
jnp.array(rng_key_tuple, dtype=jnp.uint32)
)
else:
rng_key = jax.random.PRNGKey(42 + shard_idx)
# Convert shard arrays to JAX.
# Two wire formats are accepted (set by fit_cmc_sharded):
#
# Element-wise (random strategy): "t1", "t2", "time_grid" keys present.
# c2_data is a flat 1-D array of shape (n_pairs,).
# Worker builds ShardGrid and calls compute_c2_elementwise → (n_pairs,).
#
# Meshgrid (contiguous strategy): "t" key present.
# c2_data is a 2-D array of shape (shard_n, shard_n).
# Worker calls compute_c2_heterodyne → (shard_n, shard_n).
c2_jax = jnp.asarray(shard_data["c2_data"])
# Build ShardGrid for element-wise path if t1/t2/time_grid are present
_shard_grid = None
if shard_data.get("t1") is not None and shard_data.get("time_grid") is not None:
_t1_jax = jnp.asarray(shard_data["t1"])
_t2_jax = jnp.asarray(shard_data["t2"])
_time_grid_jax = jnp.asarray(shard_data["time_grid"])
_shard_grid = _precompute_shard_grid(_time_grid_jax, _t1_jax, _t2_jax)
t_jax: jnp.ndarray | None = None # not used in element-wise path
else:
t_raw = shard_data.get("t")
t_jax = jnp.asarray(t_raw) if t_raw is not None else None
# noise_scale is the scalar prior centre for the sampled sigma site
# (homodyne parity). Prefer the explicit shard_data['noise_scale'];
# fall back to the mean of any provided sigma array, or MAD of c2.
sigma_raw = shard_data.get("sigma")
if "noise_scale" in shard_data and shard_data["noise_scale"] is not None:
noise_scale: float = float(shard_data["noise_scale"])
elif sigma_raw is not None:
noise_scale = float(jnp.mean(jnp.asarray(sigma_raw)))
else:
median_val = float(jnp.median(c2_jax))
mad = float(jnp.median(jnp.abs(c2_jax - median_val)))
noise_scale = max(mad * 1.4826, 1e-6)
# Retain a sigma_jax reference only to keep the legacy `del c2_jax, sigma_jax, ...`
# cleanup line working — but the model itself samples sigma internally.
sigma_jax: jnp.ndarray | float = noise_scale
q_val: float = float(shard_data.get("q", 1.0))
dt_val: float = float(shard_data.get("dt", 1e-3))
phi_angle: float = float(shard_data.get("phi_angle", 0.0))
contrast: float = float(shard_data.get("contrast", 1.0))
offset: float = float(shard_data.get("offset", 1.0))
# reparam_config_dict is included in the wire format for forward
# compatibility, but _shard_model samples directly from physics-space
# priors and does not apply a back-transform. The deserialization
# below is intentionally kept so that the key remains accepted without
# error; reparam_config itself is not consumed after the C5 fix that
# removed the broken reparam_to_physics_jax back-transform call.
reparam_config_dict = shard_data.get("reparam_config_dict")
if reparam_config_dict is not None:
from heterodyne.optimization.cmc.reparameterization import ReparamConfig
_reparam_config = ReparamConfig(**reparam_config_dict) # noqa: F841 — wire-compat only
# Reconstruct parameter space
from heterodyne.config.parameter_space import ParameterSpace
ps_dict: dict[str, Any] = shard_data.get("parameter_space_dict") or {}
try:
parameter_space = ParameterSpace.from_config(ps_dict)
except Exception as exc: # noqa: BLE001 — fall back to defaults
logger.warning(
"ParameterSpace.from_config failed (%s); falling back to defaults. "
"Worker priors/bounds may differ from parent config.",
exc,
)
parameter_space = ParameterSpace()
# ------------------------------------------------------------------
# Build NumPyro model (inline closure over shard-local arrays)
# ------------------------------------------------------------------
import numpyro
import numpyro.distributions as dist
# varying_physics_names excludes scaling params (contrast, offset) by
# design — they are fixed scalar args to compute_c2_heterodyne, not NUTS
# latent sites. Using this property instead of varying_names prevents the
# opaque "tuple index out of range" pytree crash when the caller's
# ParameterSpace was built from an NLSQ result that varies all 16 params.
varying_names = parameter_space.varying_physics_names
fixed_values = parameter_space.get_initial_array()
if not varying_names:
raise ValueError(
f"Shard {shard_idx}: ParameterSpace has no active physics "
"parameters. Ensure at least one of the 14 physics params is "
"set to vary."
)
_model_sites: frozenset[str] = frozenset(ALL_PARAM_NAMES)
worker_logger.debug(
"Shard %d: %d physics params vary: %s",
shard_idx,
len(varying_names),
", ".join(varying_names),
)
# Warm-start init params for NumPyro.
# Restrict to _model_sites so scaling params from the NLSQ warm-start
# dict never leak into NUTS init_params.
#
# NumPyro ≥0.21 (mcmc.py:683) checks jnp.shape(v)[0] == num_chains to
# decide whether values are pre-batched. 0-d scalars (shape=()) cause
# IndexError: tuple index out of range. Broadcasting each value to
# (num_chains,) satisfies the check and replicates the same start point
# across all chains.
#
# Values at the exact bound edge (e.g. alpha_sample=-2.0 with
# low=-2.0) map to -inf in the unconstrained transform, making the
# initial log-prob undefined. Clipping by _INIT_BOUND_EPS keeps them
# strictly inside the support.
_INIT_BOUND_EPS: float = 1e-6
_num_chains: int = config.num_chains
def _safe_init_val(name: str, raw: float) -> jnp.ndarray:
if name in parameter_space.bounds:
lo, hi = parameter_space.bounds[name]
raw = float(np.clip(raw, lo + _INIT_BOUND_EPS, hi - _INIT_BOUND_EPS))
return jnp.broadcast_to(jnp.asarray(raw), (_num_chains,))
init_params: dict[str, jnp.ndarray] = {}
if initial_values is not None:
for k, v in initial_values.items():
if k in varying_names and k in _model_sites:
init_params[k] = _safe_init_val(k, float(v))
else:
# No NLSQ warm-start: seed every physics param from registry
# defaults so init_to_median never tries to sample BetaScaled or
# TruncatedNormal distributions without a valid PRNG key.
# NumPyro ≥0.21 asserts is_prng_key(key) inside .sample(), which
# fires for any site absent from init_params when those distributions
# are used as priors (f0, contrast use BetaScaled by default).
ps_vals = parameter_space.values
for k in varying_names:
if k in ps_vals and k in _model_sites:
init_params[k] = _safe_init_val(k, float(ps_vals[k]))
# Seed sigma so init_to_median never samples HalfNormal with a
# non-key argument (NumPyro ≥0.21 asserts is_prng_key(key)).
init_params["sigma"] = jnp.broadcast_to(
jnp.asarray(noise_scale), (_num_chains,)
)
# CMC prior tempering: widen prior std by prior_width_mult = sqrt(num_shards).
# Received via shared_kwargs from run_shards() when doing sharded CMC.
prior_width_mult: float = float(shard_data.get("prior_width_multiplier", 1.0))
nlsq_uncertainties_dict: dict[str, Any] = (
shard_data.get("nlsq_uncertainties") or {}
)
nlsq_prior_width_factor: float = float(
shard_data.get("nlsq_prior_width_factor", 2.0)
)
tempered_priors_dict: dict[str, Any] = {}
if prior_width_mult != 1.0:
from heterodyne.optimization.cmc.priors import (
build_default_priors as _build_default_priors,
)
from heterodyne.optimization.cmc.priors import (
temper_priors as _temper_priors,
)
# Derive num_shards from multiplier (prior_width_mult = sqrt(num_shards))
num_shards_est = max(2, round(prior_width_mult**2))
# If NLSQ point estimates + uncertainties are available, build
# NLSQ-informed priors (TruncatedNormal centered on initial value,
# scale = unc * width_factor) BEFORE tempering. This preserves the
# NLSQ posterior contraction across shards and matches fit_cmc_jax
# (non-sharded) prior behavior.
base_priors_for_temper: dict[str, Any] = {}
if nlsq_uncertainties_dict and initial_values:
import numpyro.distributions as _dist
for _name in varying_names:
_low, _high = parameter_space.bounds[_name]
_center = (
float(initial_values[_name])
if _name in initial_values
else float(parameter_space.values[_name])
)
if _name in nlsq_uncertainties_dict:
_scale = (
float(nlsq_uncertainties_dict[_name])
* nlsq_prior_width_factor
)
_scale = max(_scale, 1e-10)
else:
# Fall back to the registry prior for this parameter
base_priors_for_temper[_name] = parameter_space.priors[
_name
].to_numpyro(_name)
continue
base_priors_for_temper[_name] = _dist.TruncatedNormal(
loc=_center,
scale=_scale,
low=float(_low),
high=float(_high),
)
else:
# Codex S1: parent forwards use_log_space_priors via shared_kwargs;
# default True for legacy callers that didn't set it.
base_priors_for_temper = _build_default_priors(
parameter_space,
use_log_space_priors=bool(
shard_data.get("use_log_space_priors", True)
),
)
tempered_priors_dict = _temper_priors(
base_priors_for_temper, num_shards_est
)
# Homodyne parity: sample sigma inside the model with HalfNormal prior
# tempered by sqrt(num_shards). prior_width_mult already equals
# sqrt(num_shards) (see line 773); fall back to 1.0 for non-sharded use.
_shard_sigma_scale = float(noise_scale) * 1.5 * max(prior_width_mult, 1.0)
def _shard_model() -> None:
"""NumPyro model for one CMC shard (14-parameter heterodyne)."""
params = jnp.asarray(fixed_values)
for i, name in enumerate(ALL_PARAM_NAMES):
if name in varying_names:
if name in tempered_priors_dict:
numpyro_dist = tempered_priors_dict[name]
else:
prior = parameter_space.priors[name]
numpyro_dist = prior.to_numpyro(name)
param = numpyro.sample(name, numpyro_dist)
params = params.at[i].set(param)
# Compute 14-parameter heterodyne c2 prediction.
# Two paths based on shard wire format (set in fit_cmc_sharded):
#
# _shard_grid present (element-wise, random strategy):
# compute_c2_elementwise → shape (n_pairs,) — matches flat c2_data.
# No N×N matrix allocation; O(n_pairs) memory.
#
# _shard_grid absent (meshgrid, contiguous strategy):
# compute_c2_heterodyne → shape (shard_n, shard_n) — matches 2-D c2_data.
#
# All variables are closure-captured from the outer scope; ruff F821
# cannot resolve closures statically.
# CONTRACT: when the unpacked shard carries t1/t2/time_grid, the
# element-wise path MUST be taken (het_7221ba99 regression guard).
# Pinned by tests/regression/test_cmc_shard_shape_roundtrip.py::
# TestShardDispatch::test_t1_t2_time_grid_preserved_through_roundtrip
# (proxy test — the dispatch happens inside this NumPyro closure
# so a direct mock would require >5 lines of fixture scaffolding;
# if you refactor this branch, update that proxy test too).
if _shard_grid is not None: # noqa: F821
c2_model = _compute_c2_elementwise( # noqa: F821
params,
_shard_grid, # noqa: F821
q_val,
dt_val,
phi_angle,
contrast,
offset,
)
else:
c2_model = compute_c2_heterodyne(
params,
t_jax, # noqa: F821
q_val,
dt_val,
phi_angle,
contrast,
offset,
)
# Track non-finite predictions so dashboards can flag bad shards.
n_nan = jnp.sum(~jnp.isfinite(c2_model))
numpyro.deterministic("n_numerical_issues", n_nan)
# Sample sigma — full homodyne CMC parity.
shard_sigma = numpyro.sample(
"sigma", dist.HalfNormal(scale=_shard_sigma_scale)
)
numpyro.sample( # noqa: F821 — closure variable c2_jax
"obs",
dist.Normal(c2_model, shard_sigma),
obs=c2_jax, # noqa: F821
)
# ------------------------------------------------------------------
# Configure NUTS / MCMC
# ------------------------------------------------------------------
_init_map: dict[str, Any] = {
"init_to_median": numpyro_init.init_to_median,
"init_to_sample": numpyro_init.init_to_sample,
"init_to_value": numpyro_init.init_to_value,
}
init_strategy_name = getattr(config, "init_strategy", "init_to_median")
init_factory = _init_map.get(init_strategy_name, numpyro_init.init_to_median)
kernel = NUTS(
_shard_model,
target_accept_prob=config.target_accept_prob,
max_tree_depth=config.max_tree_depth,
dense_mass=config.dense_mass,
init_strategy=init_factory(),
)
mcmc = MCMC(
kernel,
num_warmup=config.num_warmup,
num_samples=config.num_samples,
num_chains=config.num_chains,
chain_method="sequential", # single process = sequential chains
progress_bar=False,
)
# Capture homodyne-parity extra fields: divergence/energy plus per-step
# accept_prob, num_steps (proxy for NUTS tree depth via log2), and
# potential_energy. These power downstream diagnostics.
mcmc.run(
rng_key,
init_params=init_params,
extra_fields=_NUTS_EXTRA_FIELDS,
)
samples_raw: dict[str, Any] = mcmc.get_samples()
samples_np: dict[str, np.ndarray] = {
k: np.array(v) for k, v in samples_raw.items()
}
extra_raw: dict[str, Any] = mcmc.get_extra_fields()
extra_np: dict[str, np.ndarray] = {k: np.array(v) for k, v in extra_raw.items()}
diverging = extra_np.get("diverging")
num_divergent = int(np.sum(diverging)) if diverging is not None else 0
duration = time.perf_counter() - start_time
# Free large JAX arrays before serialisation to reduce peak memory
del c2_jax, sigma_jax, t_jax, extra_raw, samples_raw
mcmc = None
divergence_str = f", divergences: {num_divergent}" if num_divergent > 0 else ""
worker_logger.info(
"Shard %d completed in %.2fs: %d samples/chain x %d chains%s",
shard_idx,
duration,
config.num_samples,
config.num_chains,
divergence_str,
)
if num_divergent > 0:
worker_logger.warning(
"Shard %d had %d divergent transitions", shard_idx, num_divergent
)
return {
"type": "result",
"success": True,
"shard_idx": shard_idx,
"samples": samples_np,
"param_names": list(samples_np.keys()),
"n_chains": config.num_chains,
"n_samples": config.num_samples,
"extra_fields": extra_np,
"duration": duration,
"stats": {
"num_divergent": num_divergent,
"n_warmup": config.num_warmup,
"n_samples": config.num_samples,
},
}
except Exception as exc: # noqa: BLE001 — top-level worker; must convert any crash to result dict
import traceback as _tb
duration = time.perf_counter() - start_time
error_str = str(exc).lower()
if "nan" in error_str or "inf" in error_str or "singular" in error_str:
error_category = "numerical"
elif "convergence" in error_str or "diverge" in error_str:
error_category = "convergence"
elif "memory" in error_str:
error_category = "memory_error"
elif (
"tuple index" in error_str
or "index out of range" in error_str
or ("not a latent site" in error_str or "parameter_space" in error_str)
):
error_category = "config_error"
else:
error_category = "sampling"
log_exception(
worker_logger,
exc,
context={
"shard_idx": shard_idx,
"duration_s": round(duration, 2),
"error_category": error_category,
"n_points": n_points,
},
)
return {
"type": "result",
"success": False,
"shard_idx": shard_idx,
"error": str(exc),
"error_category": error_category,
"traceback": _tb.format_exc(),
"duration": duration,
}
finally:
stop_hb.set()
hb_thread.join(timeout=1)
# ---------------------------------------------------------------------------
# Top-level worker entry point (module-level for spawn pickling)
# ---------------------------------------------------------------------------
def _run_shard_worker_with_queue(
shard_idx: int,
shard_ref: dict[str, Any],
config_ref: dict[str, Any],
shared_kwargs_ref: dict[str, Any],
initial_values_ref: dict[str, Any] | None,
ps_ref: dict[str, Any],
threads_per_worker: int,
result_queue: mp.Queue,
rng_key_tuple: tuple[int, ...] | None = None,
) -> None:
"""Entry point for each per-shard spawned process.
Reconstructs all shared data from shared memory, calls
:func:`_run_shard_worker`, and puts the result dict on ``result_queue``.
Wraps the entire body in try/except so initialisation crashes are
captured and reported to the parent rather than silently lost.
This function must be defined at module level so that Python's spawn
mechanism can pickle it when creating child processes.
Args:
shard_idx: Shard index.
shard_ref: Packed shared-memory reference for per-shard arrays.
config_ref: Shared-memory reference for ``CMCConfig`` dict.
shared_kwargs_ref: Shared-memory reference for shared scalar kwargs.
initial_values_ref: Shared-memory reference for NLSQ warm-start
values, or ``None``.
ps_ref: Shared-memory reference for parameter-space dict.
threads_per_worker: Per-worker thread budget (sets OMP/MKL env).
result_queue: Multiprocessing queue for result delivery.
rng_key_tuple: Pre-generated PRNG key raw ``uint32`` tuple.
"""
try:
_init_worker_jax(
threads_per_worker=threads_per_worker,
num_chains=int(os.environ.get("HETERODYNE_CMC_NUM_CHAINS", "4")),
)
shard_data = _load_shared_shard_data(shard_ref)
config_dict = _load_shared_dict(config_ref)
shared_kwargs = _load_shared_dict(shared_kwargs_ref)
initial_values: dict[str, Any] | None = (
_load_shared_dict(initial_values_ref)
if initial_values_ref is not None
else None
)
ps_dict = _load_shared_dict(ps_ref)
# Merge shared scalars and parameter-space dict into shard_data
shard_data.update(shared_kwargs)
shard_data["parameter_space_dict"] = ps_dict
result = _run_shard_worker(
shard_idx=shard_idx,
shard_data=shard_data,
config_dict=config_dict,
initial_values=initial_values,
result_queue=result_queue,
rng_key_tuple=rng_key_tuple,
)
except Exception as exc: # noqa: BLE001 — top-level worker; must convert any crash to result dict
import traceback as _tb
result = {
"type": "result",
"success": False,
"shard_idx": shard_idx,
"error": f"Worker initialisation failed: {exc}",
"error_category": "init_crash",
"traceback": _tb.format_exc(),
"duration": 0.0,
}
# Try hard to deliver the result; if the queue is full or closed, log a
# hard ERROR so the parent's shard-reconciliation path can treat this
# shard as failed instead of hanging on queue.get(). Silent drops
# previously caused the parent either to wait forever or to misclassify
# the shard as timed-out, producing biased aggregate posteriors with no
# diagnostic signal. See Codex review 2026-05-22.
try:
result_queue.put(result, block=True, timeout=30.0)
except Exception as _delivery_exc: # noqa: BLE001 — worker exit boundary
try:
worker_logger = get_logger(f"heterodyne.cmc.worker.shard_{shard_idx}")
worker_logger.error(
"Shard %d: failed to deliver result to parent queue "
"(error=%s). Parent will treat this shard as failed.",
shard_idx,
_delivery_exc,
)
except Exception: # noqa: BLE001 — logging must not raise from worker exit
pass
# ---------------------------------------------------------------------------
# LPTScheduler
# ---------------------------------------------------------------------------
[docs]
class LPTScheduler:
"""Longest Processing Time scheduler for load balancing across cores.
Assigns shards to workers based on estimated computation time using the
LPT (Longest Processing Time first) heuristic. The highest-cost shards
are dispatched first so that the remaining shards finishing last are the
cheapest, minimising overall tail latency.
This is a simple greedy scheduler; it does not account for real-time
feedback about actual execution durations.
Attributes:
n_workers: Number of parallel workers.
_shard_order: Deque of shard indices sorted by descending cost.
"""
[docs]
def __init__(
self,
shard_costs: list[float],
n_workers: int,
) -> None:
"""Initialise the LPT scheduler.
Args:
shard_costs: Estimated cost (positive float) per shard. Higher
is more expensive.
n_workers: Number of parallel workers.
"""
if not shard_costs:
raise ValueError("shard_costs must be non-empty")
self.n_workers = max(1, n_workers)
self._shard_order: deque[int] = deque(
sorted(range(len(shard_costs)), key=lambda i: shard_costs[i], reverse=True)
)
[docs]
@classmethod
def from_shard_data(
cls,
shard_data_list: list[dict[str, Any]],
n_workers: int,
n_params: int = _N_PARAMS_HETERODYNE,
n_samples: int = 1000,
) -> LPTScheduler:
"""Build an :class:`LPTScheduler` from raw shard data dicts.
Cost is estimated via :func:`_estimate_shard_time` for each shard.
Args:
shard_data_list: Shard dicts with ``"c2_data"`` and
``"noise_scale"`` keys.
n_workers: Number of parallel workers.
n_params: Number of model parameters (default: 14).
n_samples: Expected total MCMC draws per shard.
Returns:
Configured :class:`LPTScheduler`.
"""
costs: list[float] = []
for sd in shard_data_list:
c2 = sd.get("c2_data")
n_data = len(c2) if c2 is not None else 1
costs.append(_estimate_shard_time(n_data, n_params, n_samples))
return cls(shard_costs=costs, n_workers=n_workers)
[docs]
def next_shard(self) -> int | None:
"""Pop and return the next shard index to dispatch.
Returns:
Next shard index (highest remaining cost), or ``None`` when
all shards have been dispatched.
"""
if not self._shard_order:
return None
return self._shard_order.popleft()
[docs]
def remaining(self) -> int:
"""Return the number of shards not yet dispatched."""
return len(self._shard_order)
[docs]
def as_deque(self) -> deque[int]:
"""Return the internal order deque (consumed by dispatch loop)."""
return self._shard_order
# ---------------------------------------------------------------------------
# MultiprocessingBackend
# ---------------------------------------------------------------------------
[docs]
class MultiprocessingBackend(CMCBackend):
"""CMC backend that parallelises NUTS across shards via spawned processes.
Each shard runs as an independent Python process so that JAX is
initialised fresh per shard — avoiding the shared-state issues that
arise when forking a process that already has a JAX runtime loaded.
Shared data (config, parameter space, initial values, per-shard arrays)
is placed in ``SharedDataManager`` once in the parent and accessed via
``_load_shared_*`` in each child, minimising serialisation overhead
through spawn.
The :meth:`run` method provides the standard single-shard
:class:`CMCBackend` contract (sequential chain execution, no subprocess
overhead). For multi-shard CMC, use :meth:`run_shards`, which
orchestrates the full parallel dispatch loop.
Attributes:
n_workers: Number of concurrent worker processes.
spawn_method: Multiprocessing start method (always ``"spawn"``).
_shared_mgr: Active :class:`SharedDataManager` during
:meth:`run_shards`; ``None`` otherwise.
"""
[docs]
def __init__(
self,
n_workers: int | None = None,
spawn_method: str = "spawn",
) -> None:
"""Initialise the multiprocessing backend.
Args:
n_workers: Number of worker processes. Defaults to the
estimated physical core count, capped to avoid
oversubscription.
spawn_method: Process start method. Must be ``"spawn"`` for
JAX safety. ``"fork"`` is explicitly unsupported.
Raises:
ValueError: If ``spawn_method="fork"`` is requested.
"""
if spawn_method == "fork":
raise ValueError(
"MultiprocessingBackend does not support spawn_method='fork'. "
"JAX cannot be safely shared across forked processes. "
"Use spawn_method='spawn' (default)."
)
if n_workers is None:
n_workers = max(1, _get_physical_cores())
else:
n_workers = min(n_workers, max(1, _get_physical_cores()))
self.n_workers: int = max(1, n_workers)
self.spawn_method: str = spawn_method
self._shared_mgr: SharedDataManager | None = None
logger.debug(
"MultiprocessingBackend: n_workers=%d, spawn_method=%s",
self.n_workers,
self.spawn_method,
)
# ------------------------------------------------------------------
# CMCBackend abstract methods
# ------------------------------------------------------------------
[docs]
def run(
self,
model: Callable[..., Any],
config: CMCConfig,
rng_key: jnp.ndarray,
init_params: dict[str, jnp.ndarray] | None = None,
) -> dict[str, Any]:
"""Run NUTS sampling for a single shard (standard CMCBackend contract).
For multi-shard CMC, call :meth:`run_shards` instead. This method
provides API parity with :class:`CPUBackend` and :class:`PjitBackend`
using sequential chain execution and no subprocess overhead.
Args:
model: NumPyro model function.
config: CMC configuration.
rng_key: JAX PRNG key.
init_params: Optional per-chain initial values.
Returns:
Dictionary of posterior samples from all chains.
Raises:
RuntimeError: If MCMC sampling fails.
"""
from numpyro.infer import MCMC, NUTS
from numpyro.infer import initialization as numpyro_init
logger.info(
"MultiprocessingBackend.run: single-shard sequential mode, "
"%d chains (%d warmup, %d samples)",
config.num_chains,
config.num_warmup,
config.num_samples,
)
_init_map: dict[str, Any] = {
"init_to_median": numpyro_init.init_to_median,
"init_to_sample": numpyro_init.init_to_sample,
"init_to_value": numpyro_init.init_to_value,
}
init_factory = _init_map.get(
getattr(config, "init_strategy", "init_to_median"),
numpyro_init.init_to_median,
)
kernel = NUTS(
model,
target_accept_prob=config.target_accept_prob,
max_tree_depth=config.max_tree_depth,
dense_mass=config.dense_mass,
init_strategy=init_factory(),
)
mcmc = MCMC(
kernel,
num_warmup=config.num_warmup,
num_samples=config.num_samples,
num_chains=config.num_chains,
chain_method="sequential",
progress_bar=True,
)
# Capture homodyne-parity extra fields: divergence/energy plus per-step
# accept_prob, num_steps (proxy for NUTS tree depth via log2), and
# potential_energy. These power downstream diagnostics.
mcmc.run(
rng_key,
init_params=init_params,
extra_fields=_NUTS_EXTRA_FIELDS,
)
samples = mcmc.get_samples()
logger.info("MultiprocessingBackend.run: sampling complete")
return dict(samples)
[docs]
def get_capabilities(self) -> BackendCapabilities:
"""Return multiprocessing backend capabilities.
Returns:
:class:`BackendCapabilities` indicating sharding support,
parallel shards equal to ``n_workers``.
"""
return BackendCapabilities(
supports_sharding=True,
supports_parallel_chains=True,
max_parallel_shards=self.n_workers,
)
[docs]
def validate_resources(self) -> None:
"""Check that CPU resources and multiprocessing are available.
Raises:
RuntimeError: If no JAX CPU device is found or if the
``multiprocessing`` module cannot create a spawn context.
"""
import jax
devices = jax.devices("cpu")
if not devices:
raise RuntimeError("MultiprocessingBackend: no JAX CPU devices found.")
try:
mp.get_context(self.spawn_method)
except ValueError as exc:
raise RuntimeError(
f"MultiprocessingBackend: cannot create '{self.spawn_method}' "
f"multiprocessing context: {exc}"
) from exc
logger.debug(
"MultiprocessingBackend.validate_resources: %d CPU device(s), n_workers=%d",
len(devices),
self.n_workers,
)
[docs]
def estimate_memory(
self,
n_data: int,
n_params: int,
n_chains: int,
) -> float:
"""Estimate peak memory for all concurrent workers combined.
Conservative upper bound: each worker holds one chain's live state
(params + momentum + gradients) plus the data buffer. Workers run
chains sequentially so ``n_chains`` does not multiply within a
single worker.
Args:
n_data: Number of data points per shard.
n_params: Number of model parameters.
n_chains: Number of MCMC chains per shard (not used for per-worker
estimate; included for API uniformity).
Returns:
Estimated peak memory in gigabytes.
"""
state_bytes = 3 * n_params * _BYTES_PER_FLOAT64
data_bytes = n_data * _BYTES_PER_FLOAT64
per_worker_bytes = (state_bytes + data_bytes) * _CPU_MEMORY_OVERHEAD_FACTOR
total_bytes = per_worker_bytes * self.n_workers
return total_bytes / _BYTES_PER_GB
[docs]
def cleanup(self) -> None:
"""Release shared memory and any other resources.
Idempotent — safe to call multiple times.
"""
if self._shared_mgr is not None:
self._shared_mgr.cleanup()
self._shared_mgr = None
logger.debug("MultiprocessingBackend.cleanup: complete")
# ------------------------------------------------------------------
# CMC sharded execution
# ------------------------------------------------------------------
[docs]
def run_shards(
self,
shards: list[dict[str, Any]],
config: CMCConfig,
initial_values: dict[str, Any] | None = None,
parameter_space: Any | None = None,
prior_width_multiplier: float = 1.0,
nlsq_uncertainties: dict[str, float] | None = None,
nlsq_prior_width_factor: float = 2.0,
progress_bar: bool = True,
) -> list[dict[str, Any]]:
"""Run NUTS in parallel across all CMC shards.
Orchestrates the full parallel dispatch loop:
1. Allocate shared memory for config, parameter space, initial
values, and per-shard arrays.
2. Pre-generate all PRNG keys in the parent process.
3. Dispatch shards to worker processes in LPT order.
4. Drain the result queue with adaptive polling.
5. Enforce per-shard and heartbeat timeouts.
6. Validate and return successful shard results.
Args:
shards: List of shard dicts. Each must contain at minimum
``c2_data`` (numpy array). Optional keys: ``sigma``,
``t``, ``weights``, ``noise_scale``, ``q``, ``dt``,
``phi_angle``, ``contrast``, ``offset``,
``reparam_config_dict``.
config: CMC configuration with NUTS hyperparameters and
timeout settings.
initial_values: Optional NLSQ warm-start values shared
across all shards.
parameter_space: Optional :class:`ParameterSpace` instance.
Its internal config dict is serialised into shared memory.
progress_bar: Whether to show a tqdm progress bar.
Returns:
List of validated successful result dicts, one per succeeded
shard. Each dict contains ``shard_idx``, ``samples``,
``n_chains``, ``n_samples``, ``param_names``, ``extra_fields``,
``duration``, and ``stats``.
Raises:
ValueError: If ``shards`` is empty.
RuntimeError: If all shards fail, or if the success rate falls
below ``config.min_success_rate``.
"""
if not shards:
raise ValueError("run_shards: shards list must be non-empty")
n_shards = len(shards)
actual_workers = min(self.n_workers, n_shards)
total_threads = mp.cpu_count() or 1
threads_per_worker = _compute_threads_per_worker(total_threads, actual_workers)
run_logger = with_context(
logger,
run=getattr(config, "run_id", None),
backend="multiprocessing",
)
run_logger.info(
"run_shards: %d shards, %d workers, %d threads/worker",
n_shards,
actual_workers,
threads_per_worker,
)
run_logger.info(
"Per-shard timeout: %ds, heartbeat timeout: %ds",
config.per_shard_timeout,
config.heartbeat_timeout,
)
# ---------------------------------------------------------- #
# Serialise shared data into shared memory
# ---------------------------------------------------------- #
config_dict = config.to_dict()
if parameter_space is not None and hasattr(parameter_space, "_config_dict"):
ps_dict: dict[str, Any] = parameter_space._config_dict
elif parameter_space is not None:
ps_dict = parameter_space.to_config()
run_logger.debug(
"ParameterSpace._config_dict absent; serialized via to_config() "
"(%d varying params)",
len(parameter_space.varying_names),
)
else:
ps_dict = {}
run_logger.warning(
"ParameterSpace not provided; workers will use default parameter "
"bounds (may produce unconstrained proposals)"
)
# Shared scalars extracted from the first shard (same for all shards
# in a homogeneous CMC split).
_first = shards[0]
shared_kwargs: dict[str, Any] = {
"q": _first.get("q", 1.0),
"dt": _first.get("dt", 1e-3),
"phi_angle": _first.get("phi_angle", 0.0),
"contrast": _first.get("contrast", 1.0),
"offset": _first.get("offset", 1.0),
"n_phi": _first.get("n_phi", 1),
"reparam_config_dict": _first.get("reparam_config_dict"),
"prior_width_multiplier": float(prior_width_multiplier),
# NLSQ-informed prior payload (plain dict[str, float], cross-process safe).
# When non-empty, workers center their tempered TruncatedNormal priors on
# initial_values[name] with scale = nlsq_uncertainties[name] * nlsq_prior_width_factor.
"nlsq_uncertainties": dict(nlsq_uncertainties)
if nlsq_uncertainties
else {},
"nlsq_prior_width_factor": float(nlsq_prior_width_factor),
# Codex S1: forward use_log_space_priors so the worker's
# tempered-default branch matches the parent's prior factory.
"use_log_space_priors": bool(getattr(config, "use_log_space_priors", True)),
}
# Build per-shard numpy dicts for shared memory packing
shard_data_list: list[dict[str, Any]] = []
for shard in shards:
c2 = shard.get("c2_data")
shard_data_list.append(
{
"c2_data": np.asarray(c2) if c2 is not None else None,
"sigma": (
np.asarray(shard["sigma"])
if shard.get("sigma") is not None
else None
),
"t": (
np.asarray(shard["t"]) if shard.get("t") is not None else None
),
"t1": (
np.asarray(shard["t1"]) if shard.get("t1") is not None else None
),
"t2": (
np.asarray(shard["t2"]) if shard.get("t2") is not None else None
),
"time_grid": (
np.asarray(shard["time_grid"])
if shard.get("time_grid") is not None
else None
),
"weights": (
np.asarray(shard["weights"])
if shard.get("weights") is not None
else None
),
"noise_scale": float(shard.get("noise_scale", 0.1)),
}
)
shared_mgr = SharedDataManager()
self._shared_mgr = shared_mgr
try:
shared_config_ref = shared_mgr.create_shared_dict("config", config_dict)
shared_ps_ref = shared_mgr.create_shared_dict("ps", ps_dict)
shared_kwargs_ref = shared_mgr.create_shared_dict("kwargs", shared_kwargs)
shared_iv_ref: dict[str, Any] | None = None
if initial_values is not None:
shared_iv_ref = shared_mgr.create_shared_dict(
"init_vals", initial_values
)
shared_shard_refs = shared_mgr.create_shared_shard_arrays(shard_data_list)
except Exception: # noqa: BLE001 — cleanup-and-reraise; must run shared_mgr.cleanup() on any failure
shared_mgr.cleanup()
self._shared_mgr = None
raise
# Free numpy copies after they are copied into shared memory
del shard_data_list
# Sentinel variables (defined before try so finally never NameErrors)
_saved_env: dict[str, str | None] = {}
active_processes: dict[int, tuple[mp.Process, float]] = {}
pbar = None
try:
run_logger.debug(
"Shared memory allocated: %d blocks",
len(shared_mgr._shared_blocks),
)
# Pre-generate PRNG keys in parent (batch optimisation)
seed = config.seed if config.seed is not None else 42
shard_keys = _generate_shard_keys(n_shards, seed=seed)
run_logger.debug("Pre-generated %d PRNG keys (seed=%d)", n_shards, seed)
ctx = mp.get_context(self.spawn_method)
result_queue: mp.Queue = ctx.Queue()
# Temporarily override env for spawned workers to prevent
# thread oversubscription inherited from the parent process.
_worker_env_overrides: dict[str, str] = {
"OMP_NUM_THREADS": str(threads_per_worker),
"MKL_NUM_THREADS": str(threads_per_worker),
"OPENBLAS_NUM_THREADS": str(threads_per_worker),
"VECLIB_MAXIMUM_THREADS": str(threads_per_worker),
# Pass num_chains so workers set XLA device count dynamically
"HETERODYNE_CMC_NUM_CHAINS": str(config.num_chains),
}
_worker_env_clear = ["OMP_PROC_BIND", "OMP_PLACES"]
for key in _worker_env_clear:
_saved_env[key] = os.environ.pop(key, None)
for key, val in _worker_env_overrides.items():
_saved_env[key] = os.environ.get(key)
os.environ[key] = val
# LPT scheduling: dispatch highest-cost shards first
pending_shards = _compute_lpt_schedule(
[
{
"c2_data": shards[i].get("c2_data"),
"noise_scale": float(shards[i].get("noise_scale", 0.1)),
}
for i in range(n_shards)
]
)
if n_shards > 1:
run_logger.debug("LPT dispatch order: %s", list(pending_shards))
results: list[dict[str, Any]] = []
completed_count = 0
recorded_shards: set[int] = set()
last_heartbeat: dict[int, float] = {}
success_count = 0
# Early-abort: if >50% of first N shards fail, terminate early
early_abort_threshold = 0.5
early_abort_sample_size = min(10, n_shards)
failure_categories: dict[str, int] = {
"timeout": 0,
"heartbeat_timeout": 0,
"crash": 0,
"numerical": 0,
"convergence": 0,
"memory_error": 0,
"config_error": 0,
"sampling": 0,
"init_crash": 0,
"unknown": 0,
}
early_abort_triggered = False
pbar = tqdm(
total=n_shards,
desc="CMC shards",
disable=not progress_bar,
unit="shard",
position=0,
leave=True,
dynamic_ncols=True,
)
pbar.set_postfix_str("starting...")
pbar.refresh()
start_time = time.time()
poll_interval_min = 0.5
poll_interval_max = 5.0
poll_interval = poll_interval_min
last_completion_time = start_time
status_log_interval = 300.0 # parent status log every 5 minutes
last_status_log = start_time
shards_launched = 0
per_shard_timeout = config.per_shard_timeout
while completed_count < n_shards:
# -------------------------------------------------- #
# Drain the result queue
# -------------------------------------------------- #
while True:
try:
message: dict[str, Any] = result_queue.get_nowait()
except queue.Empty:
break
except Exception as _qexc: # noqa: BLE001 — best-effort queue drain; any IPC error breaks the loop
run_logger.warning("Queue read error: %s", _qexc)
break
msg_type = message.get("type")
msg_shard_idx = message.get("shard_idx")
if msg_type == "heartbeat" and msg_shard_idx is not None:
last_heartbeat[msg_shard_idx] = time.time()
continue
if msg_type == "result" or message.get("success") is not None:
if (
msg_shard_idx is not None
and msg_shard_idx in recorded_shards
):
run_logger.debug(
"Ignoring duplicate result for shard %d",
msg_shard_idx,
)
continue
results.append(message)
if msg_shard_idx is not None:
recorded_shards.add(msg_shard_idx)
completed_count += 1
pbar.update(1)
last_completion_time = time.time()
poll_interval = poll_interval_min
if message.get("success"):
success_count += 1
pbar.set_postfix(
shard=message.get("shard_idx", "?"),
time=f"{message.get('duration', 0):.1f}s",
)
else:
category = message.get("error_category", "unknown")
if category in failure_categories:
failure_categories[category] += 1
else:
failure_categories["unknown"] += 1
pbar.set_postfix(
shard=message.get("shard_idx", "?"),
status="failed",
)
# Early-abort check after first N completions
if (
not early_abort_triggered
and completed_count >= early_abort_sample_size
and completed_count <= early_abort_sample_size + 2
):
total_failures = sum(failure_categories.values())
failure_rate = total_failures / completed_count
if failure_rate > early_abort_threshold:
early_abort_triggered = True
run_logger.error(
"EARLY ABORT: %.1f%% failure rate in first "
"%d shards exceeds %.0f%% threshold. "
"Failure breakdown: %s",
failure_rate * 100,
completed_count,
early_abort_threshold * 100,
failure_categories,
)
pending_shards.clear()
for _idx, (_proc, _) in list(active_processes.items()):
run_logger.info(
"Terminating shard %d (early abort)",
_idx,
)
_proc.terminate()
_proc.join(timeout=2)
if _proc.is_alive():
_proc.kill()
_proc.join(timeout=1)
active_processes.pop(_idx, None)
if msg_shard_idx in active_processes:
_proc, _ = active_processes.pop(msg_shard_idx)
if _proc.is_alive():
_proc.join(timeout=1)
continue
if run_logger.isEnabledFor(logging.DEBUG):
run_logger.debug(
"Ignoring unexpected queue message: %s", message
)
# -------------------------------------------------- #
# Launch new processes up to capacity
# -------------------------------------------------- #
while len(active_processes) < actual_workers and pending_shards:
next_shard_idx = pending_shards.popleft()
process = ctx.Process(
target=_run_shard_worker_with_queue,
args=(
next_shard_idx,
shared_shard_refs[next_shard_idx],
shared_config_ref,
shared_kwargs_ref,
shared_iv_ref,
shared_ps_ref,
threads_per_worker,
result_queue,
shard_keys[next_shard_idx],
),
)
process.start()
_now = time.time()
active_processes[next_shard_idx] = (process, _now)
last_heartbeat[next_shard_idx] = _now
shards_launched += 1
# -------------------------------------------------- #
# Check process health (timeout / heartbeat / crash)
# -------------------------------------------------- #
for _idx, (_process, _proc_start) in list(active_processes.items()):
if _idx in recorded_shards:
del active_processes[_idx]
continue
_now = time.time()
_elapsed = _now - _proc_start
_last_active = last_heartbeat.get(_idx, _proc_start)
_inactive = _now - _last_active
if not _process.is_alive():
_process.join(timeout=1)
_exit_code = _process.exitcode
del active_processes[_idx]
if _idx not in recorded_shards:
if _exit_code is not None and _exit_code < 0:
import signal as _signal
try:
_sig_name = _signal.Signals(-_exit_code).name
except ValueError:
_sig_name = str(-_exit_code)
_err = (
f"Process killed by signal {_sig_name} "
f"(exit_code={_exit_code})"
)
elif _exit_code is not None and _exit_code > 0:
_err = (
f"Process exited with error "
f"(exit_code={_exit_code})"
)
else:
_err = "Process exited without returning a result"
results.append(
{
"type": "result",
"success": False,
"shard_idx": _idx,
"error": _err,
"error_category": "crash",
"duration": _elapsed,
}
)
recorded_shards.add(_idx)
failure_categories["crash"] += 1
completed_count += 1
pbar.update(1)
pbar.set_postfix(shard=_idx, status="no-result")
elif _elapsed > per_shard_timeout:
run_logger.warning(
"Shard %d exceeded runtime limit: %.0fs "
"(limit: %ds), terminating (pid=%s)",
_idx,
_elapsed,
per_shard_timeout,
_process.pid,
)
_process.terminate()
_process.join(timeout=5)
if _process.is_alive():
_process.kill()
_process.join(timeout=2)
del active_processes[_idx]
if _idx not in recorded_shards:
results.append(
{
"type": "result",
"success": False,
"shard_idx": _idx,
"error": (
f"Runtime timeout after {_elapsed:.0f}s "
f"(limit: {per_shard_timeout}s)"
),
"error_category": "timeout",
"duration": _elapsed,
}
)
recorded_shards.add(_idx)
failure_categories["timeout"] += 1
completed_count += 1
pbar.update(1)
pbar.set_postfix(shard=_idx, status="timeout")
elif _inactive > config.heartbeat_timeout:
run_logger.warning(
"Shard %d unresponsive for %.0fs "
"(heartbeat timeout: %ds), terminating (pid=%s)",
_idx,
_inactive,
config.heartbeat_timeout,
_process.pid,
)
_process.terminate()
_process.join(timeout=5)
if _process.is_alive():
_process.kill()
_process.join(timeout=2)
del active_processes[_idx]
if _idx not in recorded_shards:
results.append(
{
"type": "result",
"success": False,
"shard_idx": _idx,
"error": (
f"Unresponsive after {_inactive:.0f}s "
f"(heartbeat timeout: "
f"{config.heartbeat_timeout}s)"
),
"error_category": "heartbeat_timeout",
"duration": _elapsed,
}
)
recorded_shards.add(_idx)
failure_categories["heartbeat_timeout"] += 1
completed_count += 1
pbar.update(1)
pbar.set_postfix(shard=_idx, status="frozen")
# -------------------------------------------------- #
# Progress bar refresh
# -------------------------------------------------- #
if completed_count < n_shards:
_elapsed_total = time.time() - start_time
_mins, _secs = divmod(int(_elapsed_total), 60)
_hrs, _mins = divmod(_mins, 60)
if _hrs > 0:
pbar.set_postfix_str(
f"active={len(active_processes)} "
f"elapsed={_hrs}h{_mins:02d}m"
)
else:
pbar.set_postfix_str(
f"active={len(active_processes)} "
f"elapsed={_mins}m{_secs:02d}s"
)
_ts = time.time()
if _ts - last_status_log >= status_log_interval:
_active_hb = {
k: f"{_ts - last_heartbeat.get(k, _ts):.0f}s"
for k in active_processes
}
run_logger.info(
"CMC status: %d/%d complete; active=%d; "
"launched=%d; heartbeats=%s",
completed_count,
n_shards,
len(active_processes),
shards_launched,
_active_hb,
)
last_status_log = _ts
# Adaptive poll: grow interval during slow periods
_since_completion = time.time() - last_completion_time
if _since_completion > 30.0:
poll_interval = min(poll_interval * 1.1, poll_interval_max)
time.sleep(poll_interval)
# Orphan detection: mark stragglers when no activity remains
if (
not active_processes
and not pending_shards
and completed_count < n_shards
):
_missing = set(range(n_shards)) - recorded_shards
for _idx in sorted(_missing):
results.append(
{
"success": False,
"shard_idx": _idx,
"error": "Shard exited without emitting a result",
"error_category": "crash",
"duration": None,
}
)
recorded_shards.add(_idx)
completed_count += 1
pbar.update(1)
pbar.set_postfix(shard=_idx, status="no-result")
except KeyboardInterrupt:
run_logger.warning("Interrupted — terminating all active processes")
for _idx, (_process, _) in active_processes.items():
run_logger.debug("Terminating shard %d (pid=%s)", _idx, _process.pid)
_process.terminate()
_process.join(timeout=2)
raise
finally:
if pbar is not None:
pbar.close()
for _idx, (_process, _) in list(active_processes.items()):
if _process.is_alive():
run_logger.warning("Cleaning up orphan process for shard %d", _idx)
_process.terminate()
_process.join(timeout=2)
# Restore parent environment to pre-spawn state
for _key, _val in _saved_env.items():
if _val is None:
os.environ.pop(_key, None)
else:
os.environ[_key] = _val
shared_mgr.cleanup()
self._shared_mgr = None
# ---------------------------------------------------------- #
# Collect and validate results
# ---------------------------------------------------------- #
return self._collect_results(results, n_shards, config)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _prepare_shared_data(
self,
shards: list[dict[str, Any]],
config_dict: dict[str, Any],
shared_kwargs: dict[str, Any],
initial_values: dict[str, Any] | None,
ps_dict: dict[str, Any],
) -> tuple[
SharedDataManager,
dict[str, Any],
dict[str, Any],
dict[str, Any],
dict[str, Any] | None,
list[dict[str, Any]],
]:
"""Create all shared memory allocations for worker data.
Called by :meth:`run_shards` before spawning processes.
Args:
shards: Raw shard dicts.
config_dict: Serialised CMCConfig.
shared_kwargs: Shared scalar kwargs (q, dt, phi_angle, …).
initial_values: Optional warm-start values.
ps_dict: Parameter-space config dict.
Returns:
Tuple of ``(mgr, config_ref, kwargs_ref, ps_ref, iv_ref,
shard_refs)``.
"""
shard_data_list: list[dict[str, Any]] = []
for shard in shards:
c2 = shard.get("c2_data")
shard_data_list.append(
{
"c2_data": np.asarray(c2) if c2 is not None else None,
"sigma": (
np.asarray(shard["sigma"])
if shard.get("sigma") is not None
else None
),
"t": (
np.asarray(shard["t"]) if shard.get("t") is not None else None
),
"t1": (
np.asarray(shard["t1"]) if shard.get("t1") is not None else None
),
"t2": (
np.asarray(shard["t2"]) if shard.get("t2") is not None else None
),
"time_grid": (
np.asarray(shard["time_grid"])
if shard.get("time_grid") is not None
else None
),
"weights": (
np.asarray(shard["weights"])
if shard.get("weights") is not None
else None
),
"noise_scale": float(shard.get("noise_scale", 0.1)),
}
)
mgr = SharedDataManager()
try:
config_ref = mgr.create_shared_dict("config", config_dict)
kwargs_ref = mgr.create_shared_dict("kwargs", shared_kwargs)
ps_ref = mgr.create_shared_dict("ps", ps_dict)
iv_ref: dict[str, Any] | None = None
if initial_values is not None:
iv_ref = mgr.create_shared_dict("init_vals", initial_values)
shard_refs = mgr.create_shared_shard_arrays(shard_data_list)
except Exception: # noqa: BLE001 — cleanup-and-reraise; must run mgr.cleanup() on any failure
mgr.cleanup()
raise
return mgr, config_ref, kwargs_ref, ps_ref, iv_ref, shard_refs
def _create_worker_configs(
self,
shards: list[dict[str, Any]],
config: CMCConfig,
) -> list[dict[str, Any]]:
"""Build per-shard lightweight dicts for LPT cost estimation.
Args:
shards: Raw shard dicts.
config: CMC configuration (unused; reserved for future use).
Returns:
List of dicts with ``c2_data`` and ``noise_scale`` keys
(sufficient for :func:`_compute_lpt_schedule`).
"""
return [
{
"c2_data": shards[i].get("c2_data"),
"noise_scale": float(shards[i].get("noise_scale", 0.1)),
}
for i in range(len(shards))
]
def _collect_results(
self,
results: list[dict[str, Any]],
n_shards: int,
config: CMCConfig,
) -> list[dict[str, Any]]:
"""Gather and validate worker results.
Filters failed shards, validates sample integrity, and checks
success rate against config thresholds.
Args:
results: Raw result dicts from worker processes.
n_shards: Total shard count (for rate computation).
config: CMC configuration (carries success-rate thresholds).
Returns:
Validated successful result dicts.
Raises:
RuntimeError: If all shards fail.
"""
successful: list[dict[str, Any]] = []
for res in results:
if res.get("success"):
try:
_validate_worker_result(res)
successful.append(res)
except ValueError as _ve:
logger.warning(
"Shard %s validation failed: %s",
res.get("shard_idx", "?"),
_ve,
)
else:
_err_cat = res.get("error_category", "unknown")
logger.warning(
"Shard %s failed [%s]: %s",
res.get("shard_idx", "?"),
_err_cat,
res.get("error", "unknown"),
)
if res.get("traceback"):
_tb_log = (
logger.error
if _err_cat in {"config_error", "init_crash", "unknown"}
else logger.debug
)
_tb_log(
"Shard %s traceback:\n%s",
res.get("shard_idx", "?"),
res["traceback"],
)
if not successful:
error_categories_summary: dict[str, int] = {}
for res in results:
if not res.get("success"):
_cat = res.get("error_category", "unknown")
error_categories_summary[_cat] = (
error_categories_summary.get(_cat, 0) + 1
)
# Log as ERROR and return empty list — do NOT raise. fit_cmc_sharded
# pads missing shards with failed placeholders and _combine_shard_posteriors
# returns a degenerate CMCResult (all_shards_failed=True) gracefully.
# Raising here crashed the CLI without saving any diagnostic output
# (het_c7548ee8 failure mode: all 47 shards timeout with no NLSQ warmstart).
logger.error(
"All %d shards failed. Error categories: %s. "
"Returning empty result list — caller will produce a degenerate CMCResult. "
"Most likely cause: no NLSQ warm-start provided for a large dataset. "
"Fix: run NLSQ first (optimizer: nlsq) then re-run CMC.",
n_shards,
error_categories_summary,
)
return []
success_rate = len(successful) / n_shards
if success_rate < config.min_success_rate_warning:
logger.warning(
"Success rate %.1f%% below warning threshold %.1f%% "
"— consider investigating failed shards",
success_rate * 100,
config.min_success_rate_warning * 100,
)
if success_rate < config.min_success_rate:
# Build per-category shard index list for the error message.
_fail_cats: dict[str, list[int]] = {}
for _res in results:
if not _res.get("success"):
_cat = _res.get("error_category", "unknown")
_fail_cats.setdefault(_cat, []).append(_res.get("shard_idx", -1))
_shard_summary = "; ".join(
f"{_cat}: shards {_idxs}" for _cat, _idxs in sorted(_fail_cats.items())
)
_timeout_n = len(_fail_cats.get("timeout", []))
_advice: list[str] = []
if _timeout_n:
_advice.append(
f"increase per_shard_timeout (currently {config.per_shard_timeout}s)"
" or reduce num_warmup/num_samples"
)
if len(_fail_cats) - (1 if _timeout_n else 0) > 0:
_advice.append(
"inspect shard error logs above for convergence failures"
)
_advice.append("lower min_success_rate in CMCConfig")
# Log as ERROR and continue — do NOT raise. _combine_shard_posteriors
# already handles partial-failure gracefully and produces a degenerate
# CMCResult with convergence_passed=False and diagnostic metadata.
# Raising here prevented any result from being saved to disk.
logger.error(
"CMC shard success rate %.1f%% is below the configured minimum %.1f%%. "
"Failed shards — %s. Suggested fixes: %s. "
"Proceeding with %d/%d successful shards — result will be degenerate.",
success_rate * 100,
config.min_success_rate * 100,
_shard_summary,
"; ".join(_advice),
len(successful),
n_shards,
)
valid_durations = [
res["duration"] for res in successful if res.get("duration") is not None
]
if valid_durations:
_sorted = sorted(valid_durations)
logger.debug(
"Shard timing: n=%d, min=%.1fs, max=%.1fs, median=%.1fs",
len(valid_durations),
min(valid_durations),
max(valid_durations),
_sorted[len(_sorted) // 2],
)
logger.info(
"run_shards complete: %d/%d shards succeeded",
len(successful),
n_shards,
)
return successful
def _handle_worker_failure(
self,
shard_idx: int,
error: Exception,
) -> dict[str, Any]:
"""Build a failure result dict for a shard that raised an exception.
Args:
shard_idx: Shard index.
error: Exception caught from the worker.
Returns:
Failure result dict with ``success=False``, ``shard_idx``,
``error``, ``error_category``, and ``duration`` keys.
"""
import traceback as _tb
error_str = str(error).lower()
if "nan" in error_str or "inf" in error_str:
category = "numerical"
elif "memory" in error_str:
category = "memory_error"
elif "convergence" in error_str:
category = "convergence"
else:
category = "sampling"
return {
"type": "result",
"success": False,
"shard_idx": shard_idx,
"error": str(error),
"error_category": category,
"traceback": _tb.format_exc(),
"duration": 0.0,
}
[docs]
def is_available(self) -> bool:
"""Check whether this backend can run on the current platform.
Returns:
``True`` if the spawn multiprocessing context is available.
"""
try:
mp.get_context(self.spawn_method)
return True
except (ValueError, OSError):
return False
def __repr__(self) -> str:
return (
f"MultiprocessingBackend("
f"n_workers={self.n_workers}, "
f"spawn_method={self.spawn_method!r})"
)