Source code for heterodyne.data.memory_manager

"""Memory budget tracking for large XPCS datasets.

Provides allocation tracking, budget enforcement, and chunk-size
suggestions so that downstream code can stay within a configurable
memory envelope.
"""

from __future__ import annotations

import threading
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

import numpy as np

from heterodyne.utils.logging import get_logger

logger = get_logger(__name__)

# Fallback when psutil is unavailable
_DEFAULT_BUDGET_BYTES: int = 8 * 1024 * 1024 * 1024  # 8 GB


@dataclass
class MemoryBudget:
    """Point-in-time view of a memory budget.

    Attributes:
        total_bytes: Size of the whole budget, in bytes.
        allocated_bytes: Bytes covered by tracked allocations when the
            snapshot was taken.
        peak_bytes: Largest ``allocated_bytes`` value observed so far.
    """

    total_bytes: int
    allocated_bytes: int
    peak_bytes: int
class MemoryManager:
    """Track labelled memory allocations against a fixed byte budget.

    When *budget_bytes* is ``None`` the budget is taken from
    ``psutil.virtual_memory().available``; if psutil cannot be imported the
    module-level default budget is used instead.

    Every method that reads or writes the tracked state does so while
    holding an internal lock.

    Args:
        budget_bytes: Explicit budget in bytes, or ``None`` for auto-detect.

    Raises:
        ValueError: If an explicit *budget_bytes* is zero or negative.
    """

    def __init__(self, budget_bytes: int | None = None) -> None:
        if budget_bytes is None:
            self._total = self._detect_system_memory()
        elif budget_bytes <= 0:
            raise ValueError("budget_bytes must be positive")
        else:
            self._total = budget_bytes
        self._allocated = 0
        self._peak = 0
        # label -> bytes currently accounted to that label
        self._labels: dict[str, int] = {}
        self._lock = threading.Lock()
        logger.info("MemoryManager initialised with budget %d bytes", self._total)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def request(self, n_bytes: int, label: str) -> bool:
        """Request an allocation of *n_bytes* tracked under *label*.

        A request for a label that is already tracked replaces the earlier
        allocation, so only the new size counts against the budget. A denied
        request leaves the tracked totals exactly as they were.

        Args:
            n_bytes: Number of bytes to allocate.
            label: Human-readable label for the allocation.

        Returns:
            ``True`` if the allocation fits within the budget, else ``False``.

        Raises:
            ValueError: If *n_bytes* is negative.
        """
        if n_bytes < 0:
            raise ValueError("n_bytes must be non-negative")

        with self._lock:
            prior = self._labels.get(label, 0)
            # Total usage if this request replaced the label's prior allocation.
            projected = self._allocated - prior + n_bytes

            if projected > self._total:
                logger.warning(
                    "Allocation '%s' denied: %d bytes requested, %d/%d bytes used",
                    label,
                    n_bytes,
                    self._allocated,
                    self._total,
                )
                return False

            self._labels[label] = n_bytes
            self._allocated = projected
            self._peak = max(self._peak, projected)

            logger.debug(
                "Allocated '%s': %d bytes (%d/%d used)",
                label,
                n_bytes,
                self._allocated,
                self._total,
            )
            return True

    def release(self, label: str) -> None:
        """Release a tracked allocation.

        No-op if the label is not currently tracked.

        Args:
            label: The label passed to :meth:`request`.
        """
        with self._lock:
            freed = self._labels.pop(label, 0)
            self._allocated -= freed
            if freed > 0:
                logger.debug(
                    "Released '%s': %d bytes (%d/%d used)",
                    label,
                    freed,
                    self._allocated,
                    self._total,
                )

    @staticmethod
    def estimate_array_size(
        shape: tuple[int, ...],
        dtype: np.dtype | type = np.float64,
    ) -> int:
        """Estimate the memory footprint of an array.

        Args:
            shape: Array dimensions. An empty tuple counts as one element.
            dtype: NumPy dtype (or anything ``np.dtype`` accepts).

        Returns:
            Product of the dimensions times the dtype's item size, in bytes.
        """
        item_size = np.dtype(dtype).itemsize
        n_elements = 1
        for dim in shape:
            n_elements *= dim
        return n_elements * item_size

    def get_budget(self) -> MemoryBudget:
        """Return a snapshot of the current budget state.

        Returns:
            MemoryBudget with the total, currently allocated and peak bytes.
        """
        with self._lock:
            return MemoryBudget(
                total_bytes=self._total,
                allocated_bytes=self._allocated,
                peak_bytes=self._peak,
            )

    def suggest_chunk_size(
        self,
        total_elements: int,
        element_bytes: int,
    ) -> int:
        """Suggest a chunk size that fits within the available budget.

        The chunk is sized to half of the *remaining* budget so that headroom
        is left for intermediate buffers. When even that is smaller than one
        element, a single element is still suggested.

        Args:
            total_elements: Total number of elements to process.
            element_bytes: Bytes per element; must be positive.

        Returns:
            Suggested number of elements per chunk (always >= 1).

        Raises:
            ValueError: If *element_bytes* is zero or negative.
        """
        # Reject non-positive sizes up front instead of failing with a
        # ZeroDivisionError (or producing a meaningless result) below.
        if element_bytes <= 0:
            raise ValueError("element_bytes must be positive")

        with self._lock:
            remaining = self._total - self._allocated

        # Use at most half the remaining budget, but never less than one element.
        usable = max(remaining // 2, element_bytes)
        chunk = usable // element_bytes
        # Clamp to [1, total_elements]
        return max(1, min(chunk, total_elements))

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _detect_system_memory() -> int:
        """Return psutil's available-memory figure, or the default budget."""
        try:
            import psutil

            budget = int(psutil.virtual_memory().available)
        except (ImportError, AttributeError):
            logger.info(
                "psutil unavailable; using default budget of %d bytes",
                _DEFAULT_BUDGET_BYTES,
            )
            return _DEFAULT_BUDGET_BYTES

        logger.info("Auto-detected available memory: %d bytes", budget)
        return budget
# ---------------------------------------------------------------------------
# Feature 6: Memory-Mapped I/O
# ---------------------------------------------------------------------------
class MemoryMapManager:
    """Manage lazy, read-only access to large HDF5 datasets.

    File handles are opened read-only through h5py on first use and cached
    per resolved file path. Datasets can be handed out as lazy proxies, read
    slice by slice, or loaded fully when they fit under the resident limit.

    Instances can be used as context managers; leaving the ``with`` block
    closes every cached handle.

    Args:
        max_resident_bytes: Largest dataset size, in bytes, that
            :meth:`materialize` will load. Defaults to 2 GB.
    """

    def __init__(self, max_resident_bytes: int = 2 * 1024 * 1024 * 1024) -> None:
        self._max_resident_bytes = max_resident_bytes
        self._handles: dict[str, Any] = {}  # resolved path -> h5py.File
        # Re-entrant so public methods can call each other while locked.
        self._lock = threading.RLock()
        logger.debug(
            "MemoryMapManager initialised (max_resident=%d bytes)",
            max_resident_bytes,
        )

    def _get_handle(self, file_path: Path | str) -> Any:
        """Return a cached read-only file handle, opening it if necessary.

        The caller is expected to hold ``self._lock``.

        Raises:
            ImportError: If a new handle must be opened and h5py is missing.
        """
        key = str(Path(file_path).resolve())
        if key not in self._handles:
            # h5py is only needed when a file actually has to be opened.
            try:
                import h5py
            except ImportError as exc:
                raise ImportError(
                    "h5py is required for MemoryMapManager; install it via 'uv add h5py'"
                ) from exc
            logger.debug("Opening HDF5 handle: %s", key)
            self._handles[key] = h5py.File(key, "r")
        return self._handles[key]

    @staticmethod
    def _dataset_nbytes(dataset: Any) -> int:
        """Return shape-product times item size for an opened dataset."""
        n_elements = 1
        for dim in dataset.shape:
            n_elements *= dim
        return int(n_elements * dataset.dtype.itemsize)

    def open_dataset(
        self,
        file_path: Path | str,
        dataset_path: str,
    ) -> Any:
        """Open an HDF5 dataset for lazy access.

        Returns the object obtained by indexing the file handle with
        *dataset_path* without converting it to an array. This method never
        refuses a dataset: when the estimated size exceeds
        ``max_resident_bytes`` it only logs a warning. Use :meth:`read_slice`
        for explicit partial reads, or :meth:`materialize` when an in-memory
        ndarray is really needed (that method does enforce the limit).

        Args:
            file_path: Path to HDF5 file.
            dataset_path: Internal HDF5 dataset path
                (e.g., "/exchange/C2T_all/c2_00001").

        Returns:
            The dataset object from the cached file handle.
        """
        with self._lock:
            dataset = self._get_handle(file_path)[dataset_path]
            estimated = self._dataset_nbytes(dataset)

            if estimated > self._max_resident_bytes:
                logger.warning(
                    "Dataset '%s' estimated size %d bytes exceeds max_resident %d bytes; "
                    "returning lazy h5py.Dataset proxy — use read_slice() for partial access "
                    "or materialize() if full in-memory copy is required",
                    dataset_path,
                    estimated,
                    self._max_resident_bytes,
                )

            return dataset

    def materialize(
        self,
        file_path: Path | str,
        dataset_path: str,
    ) -> np.ndarray:
        """Eagerly load an HDF5 dataset into memory as an ``ndarray``.

        Args:
            file_path: Path to HDF5 file.
            dataset_path: Internal HDF5 dataset path.

        Returns:
            The dataset converted with ``np.asarray``.

        Raises:
            MemoryError: If the estimated size exceeds ``max_resident_bytes``.
                Raise that limit, or use :meth:`read_slice` for partial loads.
        """
        with self._lock:
            dataset = self._get_handle(file_path)[dataset_path]
            estimated = self._dataset_nbytes(dataset)

            if estimated > self._max_resident_bytes:
                raise MemoryError(
                    f"Refusing to materialize dataset '{dataset_path}' "
                    f"({estimated} bytes > max_resident {self._max_resident_bytes} bytes); "
                    "use read_slice() or raise max_resident_bytes"
                )

            result: np.ndarray = np.asarray(dataset)
            return result

    def read_slice(
        self,
        file_path: Path | str,
        dataset_path: str,
        slices: tuple[slice, ...],
    ) -> np.ndarray:
        """Read one region of an HDF5 dataset.

        Args:
            file_path: Path to HDF5 file.
            dataset_path: Internal HDF5 dataset path.
            slices: Tuple of slice objects defining the region to read.

        Returns:
            NumPy array holding the requested region.
        """
        with self._lock:
            dataset = self._get_handle(file_path)[dataset_path]
            result: np.ndarray = np.asarray(dataset[slices])
            return result

    def estimate_dataset_size(
        self,
        file_path: Path | str,
        dataset_path: str,
    ) -> int:
        """Estimate the in-memory size of an HDF5 dataset without loading it.

        Args:
            file_path: Path to HDF5 file.
            dataset_path: Internal HDF5 dataset path.

        Returns:
            Estimated size in bytes (element count times item size).
        """
        with self._lock:
            dataset = self._get_handle(file_path)[dataset_path]
            return self._dataset_nbytes(dataset)

    def close_all(self) -> None:
        """Close all open HDF5 file handles.

        Every cached handle is closed and the cache is emptied even if one of
        the ``close()`` calls fails, so a later access never receives a handle
        that was already closed here.

        Raises:
            Exception: The first error raised by a handle's ``close()``,
                re-raised after all remaining handles have been attempted.
        """
        with self._lock:
            pending = list(self._handles.items())
            # Forget the handles first so a failure cannot leave stale entries.
            self._handles.clear()

            first_error: Exception | None = None
            for key, handle in pending:
                logger.debug("Closing HDF5 handle: %s", key)
                try:
                    handle.close()
                except Exception as exc:  # keep closing the remaining handles
                    logger.warning("Failed to close HDF5 handle %s: %s", key, exc)
                    if first_error is None:
                        first_error = exc

            if first_error is not None:
                raise first_error

    def __enter__(self) -> MemoryMapManager:
        """Support use as a context manager."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close all handles on context exit."""
        self.close_all()
# ---------------------------------------------------------------------------
# Feature 7: Adaptive Chunking
# ---------------------------------------------------------------------------
@dataclass
class ChunkInfo:
    """Description of one chunk of work along the batch axis.

    Attributes:
        start: First index covered by the chunk.
        end: Index one past the last element covered (exclusive).
        size_bytes: Estimated memory footprint of the chunk.
        priority: Processing priority; a lower number means higher priority.
    """

    start: int
    end: int
    size_bytes: int
    priority: int = 0
class AdaptiveChunker:
    """Split a batch axis into chunks sized from the current memory budget.

    The chunk size is derived from the budget the memory manager reports at
    the moment :meth:`compute_chunks` is called, scaled by a safety factor.
    Optionally, chunks starting in the first quarter of the axis are flagged
    with a higher priority (small indices, i.e. small time lag near the
    diagonal of a correlation matrix).

    Args:
        memory_manager: MemoryManager instance for budget awareness.
        safety_factor: Fraction of available memory to actually use
            (default 0.5).
    """

    def __init__(
        self,
        memory_manager: MemoryManager,
        safety_factor: float = 0.5,
    ) -> None:
        factor_in_range = 0.0 < safety_factor <= 1.0
        if not factor_in_range:
            raise ValueError("safety_factor must be in (0, 1]")
        self._safety_factor = safety_factor
        self._memory_manager = memory_manager
        logger.debug("AdaptiveChunker initialised (safety_factor=%.2f)", safety_factor)

    def compute_chunks(
        self,
        total_elements: int,
        element_bytes: int,
        prioritize_near_diagonal: bool = False,
    ) -> list[ChunkInfo]:
        """Compute adaptive chunk boundaries.

        Args:
            total_elements: Total number of elements along the batch axis.
            element_bytes: Memory per element in bytes.
            prioritize_near_diagonal: If True, chunks that start within the
                first quarter of the axis get priority 0 and all others get
                priority 1. If False, every chunk gets priority 0.

        Returns:
            Contiguous ChunkInfo objects covering ``[0, total_elements)``.

        Raises:
            ValueError: If *total_elements* or *element_bytes* is not positive.
        """
        if total_elements <= 0:
            raise ValueError("total_elements must be positive")
        if element_bytes <= 0:
            raise ValueError("element_bytes must be positive")

        snapshot = self._memory_manager.get_budget()
        remaining = snapshot.total_bytes - snapshot.allocated_bytes
        usable = int(remaining * self._safety_factor)

        # Elements per chunk, clamped to [1, total_elements].
        per_chunk = min(usable // element_bytes, total_elements)
        per_chunk = max(1, per_chunk)

        quarter_mark = total_elements // 4
        chunks: list[ChunkInfo] = []
        for lo in range(0, total_elements, per_chunk):
            hi = min(lo + per_chunk, total_elements)
            # Only chunks starting past the first quarter are demoted.
            demoted = prioritize_near_diagonal and lo >= quarter_mark
            chunks.append(
                ChunkInfo(
                    start=lo,
                    end=hi,
                    size_bytes=(hi - lo) * element_bytes,
                    priority=1 if demoted else 0,
                )
            )

        logger.debug(
            "AdaptiveChunker: %d chunks of ~%d elements "
            "(remaining=%d bytes, safety=%.2f)",
            len(chunks),
            per_chunk,
            remaining,
            self._safety_factor,
        )
        return chunks
# ---------------------------------------------------------------------------
# Feature 8: Memory Pressure Monitoring
# ---------------------------------------------------------------------------
class MemoryPressureLevel(Enum):
    """Coarse classification of system memory pressure.

    Members are listed from least to most severe; the trailing comments give
    the share of memory in use that each level stands for.
    """

    LOW = "low"  # under 50% used
    MODERATE = "moderate"  # 50-75% used
    HIGH = "high"  # 75-90% used
    CRITICAL = "critical"  # over 90% used
class MemoryPressureMonitor:
    """Classify system memory pressure, polling at a limited rate.

    Memory usage is read via ``psutil.virtual_memory()``; when psutil cannot
    be used, fixed fallback figures are substituted. Results are cached so
    the system is polled at most once per *poll_interval_seconds*.

    Args:
        poll_interval_seconds: Minimum seconds between actual system polls
            (cached between polls). Defaults to 5.0.

    Raises:
        ValueError: If *poll_interval_seconds* is zero or negative.
    """

    # Conservative fallback when psutil is unavailable: assume 4 GB available
    _FALLBACK_AVAILABLE_BYTES: int = 4 * 1024 * 1024 * 1024
    _FALLBACK_PERCENT_USED: float = 50.0

    def __init__(self, poll_interval_seconds: float = 5.0) -> None:
        if poll_interval_seconds <= 0:
            raise ValueError("poll_interval_seconds must be positive")
        self._poll_interval = poll_interval_seconds
        self._lock = threading.Lock()
        # time.monotonic() has an undefined reference point, so 0.0 is not a
        # safe "never polled" marker: a small clock reading would make the
        # first call return the placeholder cache below. -inf always forces
        # the first call to poll.
        self._last_poll_time: float = float("-inf")
        self._cached_level: MemoryPressureLevel = MemoryPressureLevel.LOW
        self._cached_available_bytes: int = self._FALLBACK_AVAILABLE_BYTES
        logger.debug(
            "MemoryPressureMonitor initialised (poll_interval=%.1fs)",
            poll_interval_seconds,
        )

    def _poll(self) -> tuple[MemoryPressureLevel, int]:
        """Poll system memory; returns (level, available_bytes).

        Returns the cached pair when the last real poll happened less than
        the poll interval ago. Must be called with self._lock held.
        """
        now = time.monotonic()
        if now - self._last_poll_time < self._poll_interval:
            return self._cached_level, self._cached_available_bytes

        try:
            import psutil

            vm = psutil.virtual_memory()
            percent_used = vm.percent
            available = int(vm.available)
        except (ImportError, AttributeError):
            logger.debug("psutil unavailable; using fallback memory figures")
            percent_used = self._FALLBACK_PERCENT_USED
            available = self._FALLBACK_AVAILABLE_BYTES

        # Each threshold is the exclusive upper bound of its level.
        if percent_used < 50.0:
            level = MemoryPressureLevel.LOW
        elif percent_used < 75.0:
            level = MemoryPressureLevel.MODERATE
        elif percent_used < 90.0:
            level = MemoryPressureLevel.HIGH
        else:
            level = MemoryPressureLevel.CRITICAL

        self._cached_level = level
        self._cached_available_bytes = available
        self._last_poll_time = now

        logger.debug(
            "MemoryPressureMonitor: %.1f%% used → %s",
            percent_used,
            level.value,
        )
        return level, available

    def current_pressure(self) -> MemoryPressureLevel:
        """Return the current memory pressure level.

        Uses the cached value if polled within ``poll_interval_seconds``.

        Returns:
            Current MemoryPressureLevel.
        """
        with self._lock:
            level, _ = self._poll()
            return level

    def available_bytes(self) -> int:
        """Return available system memory in bytes.

        Returns:
            Available memory, or the fallback figure if psutil is unavailable.
        """
        with self._lock:
            _, available = self._poll()
            return available

    def should_reduce_allocation(self) -> bool:
        """Return True if memory pressure suggests reducing allocations.

        Returns:
            True when pressure is HIGH or CRITICAL.
        """
        level = self.current_pressure()
        return level in (MemoryPressureLevel.HIGH, MemoryPressureLevel.CRITICAL)

    def recommended_budget_fraction(self) -> float:
        """Return the recommended fraction of total memory to use.

        Returns:
            1.0 for LOW, 0.75 for MODERATE, 0.5 for HIGH, 0.25 for CRITICAL.
        """
        level = self.current_pressure()
        fractions: dict[MemoryPressureLevel, float] = {
            MemoryPressureLevel.LOW: 1.0,
            MemoryPressureLevel.MODERATE: 0.75,
            MemoryPressureLevel.HIGH: 0.5,
            MemoryPressureLevel.CRITICAL: 0.25,
        }
        return fractions[level]