Source code for heterodyne.data.validators

"""Input validation utilities for XPCS data arrays.

Complements the higher-level ``validation.py`` module by providing
fine-grained, composable checks on individual arrays. Each function
returns a list of error strings; an empty list indicates valid input.
"""

from __future__ import annotations

import numpy as np

from heterodyne.utils.logging import get_logger

logger = get_logger(__name__)


[docs] def validate_correlation_shape( c2: np.ndarray, expected_shape: tuple[int, ...] | None = None, ) -> list[str]: """Validate shape of a correlation matrix. Checks that ``c2`` is 2D (single angle) or 3D (multi-angle batch), and optionally matches an expected shape. Args: c2: Correlation array to validate. expected_shape: If provided, the exact expected shape. Returns: List of error messages (empty if valid). """ errors: list[str] = [] if c2.ndim not in (2, 3): errors.append( f"Correlation data must be 2D or 3D, got {c2.ndim}D with shape {c2.shape}" ) return errors # For 2D: (n_t, n_t); for 3D: (n_phi, n_t, n_t) if c2.ndim == 2 and c2.shape[0] != c2.shape[1]: errors.append(f"2D correlation matrix must be square, got shape {c2.shape}") if c2.ndim == 3 and c2.shape[1] != c2.shape[2]: errors.append( f"3D correlation slices must be square, " f"got shape {c2.shape} (axes 1,2 differ)" ) if expected_shape is not None and c2.shape != expected_shape: errors.append(f"Shape mismatch: got {c2.shape}, expected {expected_shape}") return errors
[docs] def validate_time_arrays( t1: np.ndarray, t2: np.ndarray, ) -> list[str]: """Validate time arrays for monotonicity and matching lengths. Args: t1: First time axis array. t2: Second time axis array. Returns: List of error messages (empty if valid). """ errors: list[str] = [] if t1.ndim != 1: errors.append(f"t1 must be 1D, got {t1.ndim}D with shape {t1.shape}") if t2.ndim != 1: errors.append(f"t2 must be 1D, got {t2.ndim}D with shape {t2.shape}") if t1.ndim == 1 and t2.ndim == 1 and len(t1) != len(t2): errors.append( f"Time array lengths must match: len(t1)={len(t1)}, len(t2)={len(t2)}" ) # Monotonicity checks (only for 1D arrays) if t1.ndim == 1 and len(t1) >= 2: if not np.all(np.diff(t1) > 0): errors.append("t1 is not strictly monotonically increasing") if t2.ndim == 1 and len(t2) >= 2: if not np.all(np.diff(t2) > 0): errors.append("t2 is not strictly monotonically increasing") return errors
[docs] def validate_q_range( q: np.ndarray, q_min: float, q_max: float, ) -> list[str]: """Validate that wavevector values fall within the specified range. Args: q: Array of wavevector values. q_min: Minimum allowed wavevector. q_max: Maximum allowed wavevector. Returns: List of error messages (empty if valid). """ errors: list[str] = [] if q_min > q_max: errors.append(f"q_min ({q_min}) must be <= q_max ({q_max})") return errors if q.size == 0: errors.append("Wavevector array is empty") return errors actual_min = float(np.min(q)) actual_max = float(np.max(q)) if actual_min < q_min: errors.append( f"Wavevector values below q_min: min(q)={actual_min:.4g} < {q_min:.4g}" ) if actual_max > q_max: errors.append( f"Wavevector values above q_max: max(q)={actual_max:.4g} > {q_max:.4g}" ) return errors
[docs] def validate_weights( weights: np.ndarray, data_shape: tuple[int, ...], ) -> list[str]: """Validate weight array for non-negativity and shape compatibility. Args: weights: Weight array to validate. data_shape: Expected shape (must match weights shape). Returns: List of error messages (empty if valid). """ errors: list[str] = [] if weights.shape != data_shape: errors.append( f"Weights shape {weights.shape} does not match data shape {data_shape}" ) if np.any(weights < 0): n_negative = int(np.sum(weights < 0)) errors.append( f"Weights must be non-negative: found {n_negative} negative value(s)" ) if np.any(np.isnan(weights)): n_nan = int(np.sum(np.isnan(weights))) errors.append(f"Weights contain {n_nan} NaN value(s)") if np.any(np.isinf(weights)): n_inf = int(np.sum(np.isinf(weights))) errors.append(f"Weights contain {n_inf} infinite value(s)") return errors
[docs] def validate_no_nan( arr: np.ndarray, name: str, ) -> list[str]: """Check that an array contains no NaN or Inf values. Args: arr: Array to check. name: Descriptive name for error messages. Returns: List of error messages (empty if valid). """ errors: list[str] = [] nan_count = int(np.sum(np.isnan(arr))) if nan_count > 0: pct = 100.0 * nan_count / arr.size errors.append( f"'{name}' contains {nan_count} NaN value(s) ({pct:.2f}% of {arr.size} elements)" ) inf_count = int(np.sum(np.isinf(arr))) if inf_count > 0: pct = 100.0 * inf_count / arr.size errors.append( f"'{name}' contains {inf_count} Inf value(s) ({pct:.2f}% of {arr.size} elements)" ) return errors