Source code for heterodyne.optimization.cmc.backends.worker_pool

"""Worker pool backend for multi-shard CMC execution.

Spawns persistent workers that each run MCMC on assigned shards.
Amortizes JAX/NumPyro initialization overhead across tasks.
"""

from __future__ import annotations

import multiprocessing
import multiprocessing.context
import multiprocessing.process
import os
import queue
import threading
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor
from typing import TYPE_CHECKING, Any, cast

from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    import jax.numpy as jnp

    from heterodyne.optimization.cmc.config import CMCConfig

logger = get_logger(__name__)


def _estimate_physical_workers() -> int:
    """Return a default worker count based on the machine's physical cores.

    The primary source is ``detect_cpu_info()`` from
    ``heterodyne.device.cpu``; its physical core count is trusted only when it
    lies in ``[1, logical_cores]``. One core is left free for the parent
    process. When detection raises or yields an implausible count, the logical
    CPU count from ``os.cpu_count()`` is halved (assuming two hardware threads
    per core) and one core is again reserved.

    Returns:
        Number of worker processes, never less than 1.
    """
    try:
        from heterodyne.device.cpu import detect_cpu_info

        cpu = detect_cpu_info()
        cores = cpu.physical_cores
        # Reject nonsense such as 0 cores or more physical than logical cores.
        if 1 <= cores <= cpu.logical_cores:
            return max(1, cores - 1)
    except Exception:  # noqa: BLE001 - any detection problem => use the fallback
        pass

    threads = os.cpu_count() or 2
    guessed_cores = max(1, threads // 2)  # assume SMT with 2 threads per core
    return max(1, guessed_cores - 1)


def _run_shard_worker(
    model_fn: Callable[..., Any],
    config_dict: dict[str, Any],
    shard_data: dict[str, Any],
    seed: int,
) -> dict[str, Any]:
    """Per-shard MCMC worker function.

    Runs in a subprocess. Imports JAX and NumPyro locally to avoid
    fork-safety issues with JAX, and forces the CPU platform before any
    computation.

    Args:
        model_fn: NumPyro model callable (must be picklable). NUTS calls it
            with the entries of ``shard_data`` as keyword arguments.
        config_dict: Serialized CMCConfig. ``num_warmup``, ``num_samples`` and
            ``target_accept_prob`` are required; ``max_tree_depth`` defaults
            to 10 when absent.
        shard_data: Data for this shard, forwarded to the model as keyword
            arguments through ``MCMC.run``. Pass ``{}`` (or ``None``) for a
            model that takes no data arguments. Keys that collide with
            ``MCMC.run``'s own keyword options (e.g. ``init_params``,
            ``extra_fields``) would be consumed by NumPyro, not the model.
        seed: Random seed for this shard's single chain.

    Returns:
        Dictionary mapping sample-site names to posterior sample arrays.
    """
    import jax
    from numpyro.infer import MCMC, NUTS

    # Ensure CPU backend in spawned worker process
    jax.config.update("jax_platform_name", "cpu")

    rng_key = jax.random.PRNGKey(seed)

    kernel = NUTS(
        model_fn,
        target_accept_prob=config_dict["target_accept_prob"],
        max_tree_depth=config_dict.get("max_tree_depth", 10),
    )

    # One chain per call: parallelism comes from running many workers.
    mcmc = MCMC(
        kernel,
        num_warmup=config_dict["num_warmup"],
        num_samples=config_dict["num_samples"],
        num_chains=1,
        chain_method="sequential",
        progress_bar=False,
    )

    # NumPyro forwards extra run() kwargs to the model, so the shard's data
    # actually reaches the likelihood instead of being dropped.
    mcmc.run(rng_key, **(shard_data or {}))
    return dict(mcmc.get_samples())


class WorkerPoolBackend:
    """Persistent worker pool for multi-shard CMC execution.

    Distributes MCMC work across a pool of worker processes. Each submitted
    task runs a single chain, and the per-task samples are concatenated.

    NOTE(review): :meth:`run` currently submits one task per chain with an
    empty ``shard_data`` dict, so no per-shard data is sent to the workers
    from here — confirm this is intended.
    """

    def __init__(self, n_workers: int | None = None) -> None:
        """Initialize with optional worker count.

        Args:
            n_workers: Number of workers. Defaults to physical_cores - 1
                (reserving one core for the main process). Uses
                ``detect_cpu_info()`` for accurate physical core detection;
                falls back to ``os.cpu_count() // 2`` if detection fails.
                Values below 1 are clamped to 1.
        """
        if n_workers is None:
            n_workers = _estimate_physical_workers()
        self._n_workers = max(1, n_workers)

    @property
    def n_workers(self) -> int:
        """Maximum number of worker processes used by :meth:`run`."""
        return self._n_workers

    def get_name(self) -> str:
        """Return the backend identifier ``"worker_pool"``."""
        return "worker_pool"

    @staticmethod
    def should_use_pool(n_shards: int, n_workers: int) -> bool:
        """Check if pool execution is beneficial.

        Homodyne parity: pool is used whenever there are at least 3 shards,
        regardless of the worker count. The pool's startup cost is amortised
        across shards, not workers, so the gate only needs to reject
        trivially small shard counts.

        Args:
            n_shards: Number of data shards.
            n_workers: Available workers (unused; kept for API stability).

        Returns:
            True when parallelism would be beneficial.
        """
        _ = n_workers
        return n_shards >= 3

    def run(
        self,
        model: Callable[..., Any],
        config: CMCConfig,
        rng_key: jnp.ndarray,
        init_params: dict[str, jnp.ndarray] | None = None,
    ) -> dict[str, Any]:
        """Run MCMC via worker pool.

        One task per chain is submitted to a spawn-context
        :class:`ProcessPoolExecutor`. Results are collected in submission
        order, so the concatenated output is deterministic for a given key.
        If any chain raises, chains that have not started yet are cancelled
        before the exception propagates, so the caller is not kept waiting
        for work whose result would be discarded.

        Args:
            model: NumPyro model function (must be picklable; it is sent to
                spawned processes).
            config: CMC configuration; ``num_chains``, ``num_warmup``,
                ``num_samples``, ``target_accept_prob`` and
                ``max_tree_depth`` are read.
            rng_key: JAX PRNG key, folded with each chain index to derive one
                integer seed per chain.
            init_params: Optional initial values (not used in pool mode).

        Returns:
            Mapping from sample-site name to the per-chain sample arrays
            concatenated along the first axis.
        """
        import jax
        import jax.numpy as jnp

        n_chains = config.num_chains
        logger.info(
            "WorkerPoolBackend: dispatching %d chains across %d workers",
            n_chains,
            self._n_workers,
        )

        config_dict = {
            "num_warmup": config.num_warmup,
            "num_samples": config.num_samples,
            "target_accept_prob": config.target_accept_prob,
            "max_tree_depth": config.max_tree_depth,
        }

        # Generate deterministic seeds per chain
        seeds = [
            int(
                jax.random.randint(
                    jax.random.fold_in(rng_key, i),
                    (),
                    0,
                    2**31 - 1,
                )
            )
            for i in range(n_chains)
        ]

        all_samples: dict[str, list[Any]] = {}
        with ProcessPoolExecutor(
            max_workers=self._n_workers,
            mp_context=multiprocessing.get_context("spawn"),
        ) as pool:
            futures = [
                pool.submit(_run_shard_worker, model, config_dict, {}, seed)
                for seed in seeds
            ]
            try:
                for future in futures:
                    for name, arr in future.result().items():
                        all_samples.setdefault(name, []).append(arr)
            except BaseException:
                # The executor's __exit__ waits for outstanding work; cancel
                # the not-yet-started chains so the failure surfaces promptly.
                for pending in futures:
                    pending.cancel()
                raise

        combined = {name: jnp.concatenate(arrs) for name, arrs in all_samples.items()}
        logger.info("WorkerPoolBackend: combined %d chains", n_chains)
        return combined
# --------------------------------------------------------------------------- # PersistentWorkerPool — homodyne CMC parity (queue-based event loop) # --------------------------------------------------------------------------- # # ``WorkerPoolBackend`` above wraps :class:`ProcessPoolExecutor` and only # amortises the worker process startup cost. The class below mirrors # homodyne's :class:`WorkerPool` architecture: each worker runs an event # loop with per-worker task queues, a ready-signal handshake, a bounded # shared result queue, round-robin task submit, and a defensive # MemoryError drain so an OOM in one worker cannot deadlock the parent.
def should_use_persistent_pool(n_shards: int, n_workers: int) -> bool:
    """Decide whether a persistent pool is worth its startup cost.

    Mirrors homodyne: any run with three or more shards qualifies.
    ``n_workers`` is accepted for signature symmetry with
    ``WorkerPoolBackend.should_use_pool`` but does not affect the decision.
    """
    del n_workers  # intentionally unused
    minimum_shards = 3
    return n_shards >= minimum_shards
class PersistentWorkerPool:
    """Persistent process pool for CMC shard dispatch (homodyne parity).

    Workers are spawned once, perform a one-time initialization (e.g. JAX
    warm-up / JIT priming) under control of ``worker_init_fn``, signal
    readiness, and then process tasks from a per-worker queue until a
    ``None`` sentinel is received.

    The pool exposes a round-robin :meth:`submit`, a blocking
    :meth:`get_result`, and a :meth:`shutdown` that sends a ``None`` sentinel
    to every worker (tasks already queued ahead of it are still processed),
    joins the processes, and terminates any that refuse to exit. The class
    supports the context-manager protocol; ``__exit__`` calls
    :meth:`shutdown`.

    Args:
        n_workers: Number of persistent worker processes (clamped to >= 1).
        worker_fn: Picklable module-level callable invoked per task.
            Signature ``worker_fn(task: dict, **init_kwargs) -> dict | None``.
            Returning ``None`` skips putting a result on the shared queue,
            which is useful when ``worker_fn`` already manages its own
            result emission.
        worker_init_kwargs: One-time kwargs forwarded to both
            ``worker_init_fn`` (if provided) and every ``worker_fn`` call.
            Must be picklable.
        worker_init_fn: Optional one-time initializer with signature
            ``init_fn(worker_id: int, **init_kwargs) -> None``. Use it for
            expensive setup like JAX backend selection or JIT pre-warming.
        startup_timeout: Total number of seconds to wait for all workers to
            signal readiness before continuing with whatever subset is ready.
    """

    def __init__(
        self,
        n_workers: int,
        worker_fn: Callable[..., dict[str, Any] | None],
        worker_init_kwargs: dict[str, Any] | None = None,
        worker_init_fn: Callable[..., None] | None = None,
        startup_timeout: float = 120.0,
    ) -> None:
        self._n_workers = max(1, n_workers)
        self._worker_fn = worker_fn
        self._init_kwargs = dict(worker_init_kwargs or {})
        self._init_fn = worker_init_fn
        self._startup_timeout = startup_timeout

        ctx = multiprocessing.get_context("spawn")
        # One private task queue per worker so submit() controls placement.
        self._task_queues: list[multiprocessing.Queue] = [
            ctx.Queue() for _ in range(self._n_workers)
        ]
        # Bounded: workers block on put() once n_workers * 4 results are unread.
        self._result_queue: multiprocessing.Queue = ctx.Queue(
            maxsize=self._n_workers * 4
        )
        self._processes: list[Any] = []
        self._next_worker = 0
        self._alive = False
        self._lock = threading.Lock()
        self._start_workers(ctx)

    def _start_workers(self, ctx: Any) -> None:
        """Spawn the worker processes and wait for their ready signals.

        ``self._startup_timeout`` bounds the whole handshake, not each
        individual worker; on expiry the pool continues with whatever
        subset has reported ready.
        """
        import time

        ready_queue: multiprocessing.Queue = ctx.Queue()
        for i in range(self._n_workers):
            p = ctx.Process(
                target=_persistent_worker_event_loop,
                args=(
                    i,
                    self._task_queues[i],
                    self._result_queue,
                    self._worker_fn,
                    self._init_kwargs,
                    self._init_fn,
                    ready_queue,
                ),
                daemon=True,
            )
            p.start()
            self._processes.append(p)

        deadline = time.monotonic() + self._startup_timeout
        ready_count = 0
        while ready_count < self._n_workers:
            remaining = max(0.0, deadline - time.monotonic())
            try:
                ready_queue.get(timeout=remaining)
            except queue.Empty:
                logger.warning(
                    "PersistentWorkerPool startup timed out after %.0fs (%d/%d ready)",
                    self._startup_timeout,
                    ready_count,
                    self._n_workers,
                )
                break
            ready_count += 1

        try:
            ready_queue.close()
        except (OSError, ValueError):
            pass
        self._alive = True
        logger.info(
            "PersistentWorkerPool started: %d/%d workers ready",
            ready_count,
            self._n_workers,
        )

    @property
    def n_workers(self) -> int:
        """Number of worker processes this pool was configured with."""
        return self._n_workers

    @property
    def result_queue(self) -> multiprocessing.Queue:
        """Shared queue on which workers put their result payloads."""
        return self._result_queue

    def is_alive(self) -> bool:
        """True when the pool is started and at least one worker process lives."""
        return self._alive and any(p.is_alive() for p in self._processes)

    def submit(self, task: dict[str, Any]) -> None:
        """Round-robin submit ``task`` to the next worker's queue."""
        with self._lock:
            worker_idx = self._next_worker % self._n_workers
            self._task_queues[worker_idx].put(task)
            self._next_worker += 1

    def get_result(self, timeout: float = 300.0) -> dict[str, Any]:
        """Block until a result is available; raises ``queue.Empty`` on timeout."""
        return cast("dict[str, Any]", self._result_queue.get(timeout=timeout))

    def results_pending(self) -> bool:
        """True when the shared result queue has at least one entry.

        Based on ``Queue.empty()``, so it is a hint rather than a guarantee.
        """
        return not self._result_queue.empty()

    def shutdown(self, timeout: float = 10.0) -> None:
        """Send ``None`` sentinels to all workers, join them, and close queues.

        A no-op when the pool is not alive.

        Args:
            timeout: Seconds to wait for each worker to exit on its own
                before it is terminated (and, failing that, killed).
        """
        if not self._alive:
            return
        for tq in self._task_queues:
            try:
                tq.put(None)
            except (OSError, ValueError):
                pass

        for p in self._processes:
            p.join(timeout=timeout)
            if p.is_alive():
                logger.warning(
                    "PersistentWorkerPool: worker pid %d did not exit cleanly; terminating",
                    p.pid,
                )
                p.terminate()
                p.join(timeout=15)
                if p.is_alive():
                    p.kill()
                    p.join(timeout=5)  # reap the killed process

        for tq in self._task_queues:
            try:
                tq.close()
                # Workers have been joined/terminated above, so nobody will
                # read tasks still buffered here; don't let this queue's
                # feeder thread block interpreter exit trying to flush them.
                tq.cancel_join_thread()
            except (OSError, ValueError):
                pass
        try:
            self._result_queue.close()
        except (OSError, ValueError):
            pass
        self._alive = False
        logger.info("PersistentWorkerPool shut down")

    def __enter__(self) -> PersistentWorkerPool:
        return self

    def __exit__(self, *exc: object) -> None:
        self.shutdown()
def _persistent_worker_event_loop(
    worker_id: int,
    task_queue: multiprocessing.Queue,
    result_queue: multiprocessing.Queue,
    worker_fn: Callable[..., dict[str, Any] | None],
    init_kwargs: dict[str, Any],
    init_fn: Callable[..., None] | None,
    ready_queue: multiprocessing.Queue | None,
) -> None:
    """Persistent worker event loop — homodyne-parity.

    Calls ``init_fn`` once (if provided), then always signals
    ``ready_queue`` (when given) so the parent's startup handshake never
    waits on this worker. Tasks are read from ``task_queue`` until a ``None``
    sentinel arrives. Each non-``None`` ``worker_fn`` result is tagged with
    this process's PID under ``"worker_id"`` and put on ``result_queue``.
    Exceptions raised by ``worker_fn`` are surfaced as failure payloads
    (``"success": False``) on the shared result queue.

    Two conditions stop real work without abandoning queued tasks:

    * ``init_fn`` raises: every task this worker receives is answered with a
      failure payload instead of being executed.
    * ``worker_fn`` raises ``MemoryError``: the failing task, and every task
      routed to this worker afterwards, get failure payloads.

    In both cases the worker keeps answering until the sentinel arrives.
    Tasks the parent submitted round-robin to this worker therefore still
    produce a result, and the parent's blocking ``get_result()`` does not
    hang waiting for results that would never come.
    """
    init_error: str | None = None
    try:
        if init_fn is not None:
            init_fn(worker_id, **init_kwargs)
    except Exception as exc:  # noqa: BLE001
        logger.error(
            "PersistentWorkerPool worker %d init_fn raised: %s", worker_id, exc
        )
        init_error = f"init_fn failed: {type(exc).__name__}: {exc}"

    # Signal readiness even after a failed init so the parent doesn't hang.
    if ready_queue is not None:
        try:
            ready_queue.put(worker_id)
        except (OSError, ValueError):
            pass

    if init_error is not None:
        _fail_remaining_tasks(task_queue, result_queue, init_error)
        return

    while True:
        try:
            task = task_queue.get()
        except (OSError, EOFError):
            break
        if task is None:
            break

        task_id = task.get("task_id", "unknown")
        try:
            result = worker_fn(task, **init_kwargs)
            if result is not None:
                result["worker_id"] = os.getpid()
                result_queue.put(result)
        except MemoryError:
            result_queue.put(
                {
                    "task_id": task_id,
                    "success": False,
                    "error": "MemoryError: worker ran out of memory",
                    "worker_id": os.getpid(),
                }
            )
            # Stop real work but keep answering until the sentinel so tasks
            # routed here (now or later) don't leave the parent waiting.
            _fail_remaining_tasks(
                task_queue, result_queue, "Worker terminated due to MemoryError"
            )
            break
        except Exception as exc:  # noqa: BLE001 — surface any error as a result
            result_queue.put(
                {
                    "task_id": task_id,
                    "success": False,
                    "error": str(exc),
                    "error_type": type(exc).__name__,
                    "worker_id": os.getpid(),
                }
            )


def _fail_remaining_tasks(
    task_queue: multiprocessing.Queue,
    result_queue: multiprocessing.Queue,
    error: str,
) -> None:
    """Answer every further task with a failure payload until the ``None`` sentinel.

    Blocks on ``task_queue`` between tasks. Returns when the sentinel
    arrives or the queue becomes unusable (``OSError``/``EOFError``).
    """
    while True:
        try:
            task = task_queue.get()
        except (OSError, EOFError):
            return
        if task is None:
            return
        result_queue.put(
            {
                "task_id": task.get("task_id", "unknown"),
                "success": False,
                "error": error,
                "worker_id": os.getpid(),
            }
        )