# Source code for heterodyne.optimization.cmc.io

"""Shard I/O for CMC (Consensus Monte Carlo) results.

Provides functions to persist and reload posterior samples as ``.npz``
archives and ArviZ ``InferenceData`` objects as NetCDF files.
"""

from __future__ import annotations

import json
import math
import re
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np

from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from heterodyne.optimization.cmc.results import CMCResult

logger = get_logger(__name__)


# ------------------------------------------------------------------
# Shard I/O (NumPy .npz)
# ------------------------------------------------------------------


def save_shard_results(
    results: dict[str, np.ndarray],
    output_dir: str | Path,
    shard_id: int,
) -> Path:
    """Write one shard's posterior samples to a ``.npz`` archive.

    The archive lands at ``<output_dir>/shard_<shard_id>.npz``; the
    directory (and any missing parents) is created on demand.

    Args:
        results: Parameter name -> sample array. Each entry is stored as a
            named array in the archive.
        output_dir: Directory that receives the archive.
        shard_id: Integer shard identifier used in the file name.

    Returns:
        Path of the archive that was written.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    archive = target_dir / f"shard_{shard_id}.npz"
    np.savez(archive, **results)

    n_params = len(results)
    logger.info("Saved shard %d (%d params) to %s", shard_id, n_params, archive)
    return archive
def load_shard_results(
    output_dir: str | Path,
    shard_id: int,
) -> dict[str, np.ndarray]:
    """Read back the posterior samples stored for one shard.

    Looks for ``<output_dir>/shard_<shard_id>.npz`` and copies every array in
    the archive into a regular dict, so the archive can be closed before
    returning.

    Args:
        output_dir: Directory containing shard archives.
        shard_id: Integer shard identifier.

    Returns:
        Mapping of parameter name to sample array.

    Raises:
        FileNotFoundError: If the shard file does not exist.
    """
    archive = Path(output_dir) / f"shard_{shard_id}.npz"
    if not archive.exists():
        raise FileNotFoundError(f"Shard file not found: {archive}")

    samples: dict[str, np.ndarray] = {}
    with np.load(archive) as npz:
        # Index each member while the archive is still open.
        for name in npz.files:
            samples[name] = npz[name]

    logger.info(
        "Loaded shard %d (%d params) from %s", shard_id, len(samples), archive
    )
    return samples
def list_shards(output_dir: str | Path) -> list[int]:
    """Return the shard IDs that have an archive in *output_dir*.

    Every directory entry whose name matches ``shard_<N>.npz`` contributes
    the integer ``N``. Entries with any other name are ignored.

    Args:
        output_dir: Directory to scan.

    Returns:
        Shard IDs in ascending order; an empty list when *output_dir* is not
        a directory.
    """
    directory = Path(output_dir)
    if not directory.is_dir():
        return []

    shard_name = re.compile(r"^shard_(\d+)\.npz$")
    hits = (shard_name.match(entry.name) for entry in directory.iterdir())
    shard_ids = sorted(int(hit.group(1)) for hit in hits if hit is not None)

    logger.debug("Found %d shards in %s", len(shard_ids), directory)
    return shard_ids
# ------------------------------------------------------------------
# ArviZ InferenceData I/O (NetCDF)
# ------------------------------------------------------------------
def save_inference_data(idata: object, path: str | Path) -> Path:
    """Persist an ArviZ InferenceData object to a NetCDF file.

    The parent directory of *path* is created when missing. ``arviz`` is
    imported lazily, so it is only required when this function is called.

    Args:
        idata: ArviZ InferenceData instance.
        path: Destination file path (should end in ``.nc``).

    Returns:
        Path to the saved file.

    Raises:
        ImportError: If ``arviz`` is not installed.
    """
    try:
        import arviz
    except ImportError as err:
        raise ImportError(
            "arviz is required for InferenceData I/O. Install it with: uv add arviz"
        ) from err

    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    arviz.to_netcdf(idata, str(destination))
    logger.info("Saved InferenceData to %s", destination)
    return destination
def load_inference_data(path: str | Path) -> object:
    """Read an ArviZ InferenceData object back from a NetCDF file.

    The ``arviz`` import is attempted before the file is looked up, so a
    missing dependency is reported even when *path* does not exist.

    Args:
        path: Path to a NetCDF file previously written by
            :func:`save_inference_data`.

    Returns:
        ArviZ InferenceData object.

    Raises:
        FileNotFoundError: If the file does not exist.
        ImportError: If ``arviz`` is not installed.
    """
    try:
        import arviz
    except ImportError as err:
        raise ImportError(
            "arviz is required for InferenceData I/O. Install it with: uv add arviz"
        ) from err

    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(f"InferenceData file not found: {source}")

    loaded = arviz.from_netcdf(str(source))
    logger.info("Loaded InferenceData from %s", source)
    return loaded
# ---------------------------------------------------------------------------
# Full result serialization pipeline (homodyne parity)
# ---------------------------------------------------------------------------

SAMPLES_SCHEMA_VERSION = (1, 0)


def _diagnostic_by_name(values: Any, names: Any) -> dict[str, float]:
    """Pair each parameter name with its diagnostic value as a ``float``.

    When *values* is ``None`` every name maps to NaN; otherwise the i-th name
    maps to ``float(values[i])``.
    """
    if values is None:
        return dict.fromkeys(names, float("nan"))
    return {name: float(values[idx]) for idx, name in enumerate(names)}


def _r_hat_dict(result: CMCResult) -> dict[str, float]:
    """Return ``result.r_hat`` keyed by parameter name (NaN when unset)."""
    return _diagnostic_by_name(result.r_hat, result.parameter_names)


def _ess_bulk_dict(result: CMCResult) -> dict[str, float]:
    """Return ``result.ess_bulk`` keyed by parameter name (NaN when unset)."""
    return _diagnostic_by_name(result.ess_bulk, result.parameter_names)


def _ess_tail_dict(result: CMCResult) -> dict[str, float]:
    """Return ``result.ess_tail`` keyed by parameter name (NaN when unset)."""
    return _diagnostic_by_name(result.ess_tail, result.parameter_names)
def save_samples_npz(result: CMCResult, output_path: Path) -> None:
    """Write posterior samples and diagnostics to a compressed NPZ archive.

    ``posterior_samples`` is stored exactly as returned by
    ``result.get_samples_array()``, documented as
    ``(n_chains, n_samples, n_params)`` (ArviZ-compatible). Per-parameter
    diagnostics (``r_hat``, ``ess_bulk``, ``ess_tail``) are stored as float
    arrays with one entry per parameter name, filled with NaN when the
    corresponding attribute on *result* is ``None``. Scalar metadata is stored
    as one-element arrays.
    """
    samples_3d = result.get_samples_array()
    names = result.parameter_names
    n_params = len(names)

    def _diagnostic_array(values: Any) -> np.ndarray:
        # One float per parameter; all-NaN when the diagnostic is unavailable.
        if values is None:
            return np.full(n_params, np.nan)
        return np.array([float(values[idx]) for idx in range(n_params)])

    metadata = result.metadata

    # Prefer first-class field; fall back to metadata for legacy results.
    divergence_count = int(
        getattr(result, "divergences", 0) or metadata.get("num_divergences", 0)
    )
    # n_phi from metadata (set by fit_cmc_jax); fall back to 0 when absent.
    n_phi = int(metadata.get("n_phi", 0))

    payload = {
        "schema_version": np.array(SAMPLES_SCHEMA_VERSION),
        "posterior_samples": samples_3d,
        "param_names": np.array(names),
        "r_hat": _diagnostic_array(result.r_hat),
        "ess_bulk": _diagnostic_array(result.ess_bulk),
        "ess_tail": _diagnostic_array(result.ess_tail),
        "divergences": np.array([divergence_count]),
        "analysis_mode": np.array([str(metadata.get("analysis_mode", "unknown"))]),
        "n_phi": np.array([n_phi]),
        "n_chains": np.array([result.num_chains]),
        "n_samples": np.array([result.num_samples]),
    }
    np.savez_compressed(output_path, **payload)

    logger.info("Saved samples.npz: %s %s", output_path, samples_3d.shape)
def load_samples_npz(input_path: Path) -> dict[str, Any]:
    """Load a samples NPZ archive and return a plain Python dict.

    Parameters
    ----------
    input_path : Path
        Location of the archive. It is resolved to an absolute path before
        any validation is performed.

    Returns
    -------
    dict[str, Any]
        ``schema_version`` (tuple of ``int``), ``posterior_samples``,
        ``r_hat``, ``ess_bulk`` and ``ess_tail`` (NumPy arrays),
        ``param_names`` (list), ``analysis_mode`` (``str``) and the ``int``
        scalars ``divergences``, ``n_phi``, ``n_chains`` and ``n_samples``.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    ValueError
        If the suffix is not ``.npz`` or the path is not a regular file.
    """
    input_path = Path(input_path).resolve()
    if not input_path.exists():
        raise FileNotFoundError(f"Samples file not found: {input_path}")
    if input_path.suffix != ".npz":
        raise ValueError(f"Expected .npz file, got: {input_path.suffix}")
    if not input_path.is_file():
        raise ValueError(f"Path is not a regular file: {input_path}")

    # NPZ archives store raw numpy binary arrays; no object serialization is used.
    # Pin allow_pickle=False so a tampered archive cannot execute arbitrary code
    # via object arrays (homodyne-parity, io.load_samples_npz:155).
    # Use context manager to ensure the underlying zip file is closed after reading.
    with np.load(input_path, allow_pickle=False) as data:  # noqa: NPY002
        result = {
            # .tolist() yields built-in ints; tuple(array) would leak
            # numpy.int64 scalars, which json cannot serialize.
            "schema_version": tuple(data["schema_version"].tolist()),
            "posterior_samples": np.array(data["posterior_samples"]),
            "param_names": data["param_names"].tolist(),
            "r_hat": np.array(data["r_hat"]),
            "ess_bulk": np.array(data["ess_bulk"]),
            "ess_tail": np.array(data["ess_tail"]),
            # Scalars are stored as one-element arrays; unwrap to built-ins.
            "divergences": int(data["divergences"][0]),
            "analysis_mode": str(data["analysis_mode"][0]),
            "n_phi": int(data["n_phi"][0]),
            "n_chains": int(data["n_chains"][0]),
            "n_samples": int(data["n_samples"][0]),
        }
    return result
def samples_to_arviz(samples_data: dict[str, Any]) -> Any:
    """Build an ArviZ InferenceData object from a loaded samples dict.

    Each parameter name is paired with the slice of ``posterior_samples``
    taken along the last of its three axes, and the resulting mapping is
    handed to ``arviz.from_dict`` under the ``"posterior"`` key.

    Parameters
    ----------
    samples_data : dict[str, Any]
        Data returned by :func:`load_samples_npz`.
    """
    import arviz  # type: ignore[import-untyped]

    draws = samples_data["posterior_samples"]  # (n_chains, n_samples, n_params)

    posterior = {}
    for column, name in enumerate(samples_data["param_names"]):
        posterior[name] = draws[:, :, column]

    return arviz.from_dict({"posterior": posterior})
def save_fitted_data_npz(
    result: CMCResult,
    c2_exp: np.ndarray,
    c2_fitted: np.ndarray,
    c2_fitted_std: np.ndarray,
    t1: np.ndarray,
    t2: np.ndarray,
    phi_angles: np.ndarray,
    q: float,
    output_path: Path,
) -> None:
    """Save fitted C2 data in NLSQ-compatible NPZ format.

    Besides the inputs, the archive holds ``residuals`` (``c2_exp`` minus
    ``c2_fitted``) and a lower/upper band ``c2_fitted_5pct`` /
    ``c2_fitted_95pct`` placed 1.645 standard deviations either side of
    ``c2_fitted``. ``q`` is stored as a one-element array. *result* is
    accepted but not read by this function.
    """
    half_width = 1.645 * c2_fitted_std

    payload = {
        "c2_exp": c2_exp,
        "c2_fitted": c2_fitted,
        "residuals": c2_exp - c2_fitted,
        "q": np.array([q]),
        "phi_angles": phi_angles,
        "t1": t1,
        "t2": t2,
        "c2_fitted_std": c2_fitted_std,
        "c2_fitted_5pct": c2_fitted - half_width,
        "c2_fitted_95pct": c2_fitted + half_width,
    }
    np.savez_compressed(output_path, **payload)

    logger.info("Saved fitted_data.npz: %s shape=%s", output_path, c2_exp.shape)
def save_parameters_json(result: CMCResult, output_path: Path) -> None:
    """Write per-parameter posterior statistics to a JSON file.

    The statistics come from ``result.get_posterior_stats()``. Every value is
    converted with ``float`` and then made JSON-safe: NaN becomes ``null`` and
    the infinities become the strings ``"Infinity"`` / ``"-Infinity"``.
    """

    def _json_safe(value: float) -> float | str | None:
        if math.isnan(value):
            return None
        if not math.isinf(value):
            return value
        return "Infinity" if value > 0 else "-Infinity"

    stats_json: dict[str, dict[str, float | str | None]] = {}
    for name, per_stat in result.get_posterior_stats().items():
        stats_json[name] = {
            label: _json_safe(float(raw)) for label, raw in per_stat.items()
        }

    with open(output_path, "w", encoding="utf-8") as handle:
        json.dump(stats_json, handle, indent=2)

    logger.info("Saved parameters.json: %s (%d params)", output_path, len(stats_json))
def save_diagnostics_json(
    result: CMCResult,
    output_path: Path,
    warnings: list[str] | None = None,
) -> None:
    """Write the convergence diagnostics of *result* to a JSON file.

    The payload is assembled by ``create_diagnostics_dict`` and then made
    JSON-safe: NumPy arrays become lists, NumPy scalars become ``float``,
    NaN becomes ``null`` and the infinities become the strings
    ``"Infinity"`` / ``"-Infinity"``.
    """
    from heterodyne.optimization.cmc.diagnostics import create_diagnostics_dict

    def _to_jsonable(obj: Any) -> Any:
        # Containers are walked recursively; anything unrecognized is passed
        # through untouched for json to handle.
        if isinstance(obj, np.ndarray):
            return _to_jsonable(obj.tolist())
        if isinstance(obj, dict):
            return {key: _to_jsonable(val) for key, val in obj.items()}
        if isinstance(obj, list):
            return [_to_jsonable(item) for item in obj]
        if isinstance(obj, (np.integer, np.floating, float)):
            number = float(obj)
            if math.isnan(number):
                return None
            if math.isinf(number):
                return "Infinity" if number > 0 else "-Infinity"
            return number
        return obj

    status = "converged" if result.convergence_passed else "not_converged"

    # Prefer first-class fields on CMCResult; fall back to legacy metadata keys.
    shard_count = int(
        getattr(result, "num_shards", 0) or result.metadata.get("n_shards", 1)
    )
    divergence_count = int(
        getattr(result, "divergences", 0) or result.metadata.get("num_divergences", 0)
    )
    elapsed = result.wall_time_seconds or 0.0

    diagnostics = create_diagnostics_dict(
        r_hat=_r_hat_dict(result),
        ess_bulk=_ess_bulk_dict(result),
        ess_tail=_ess_tail_dict(result),
        divergences=divergence_count,
        convergence_status=status,
        warnings=warnings or [],
        n_chains=result.num_chains,
        n_warmup=result.num_warmup,
        n_samples=result.num_samples,
        # Only the total wall time is available; it is reported as sampling time.
        warmup_time=0.0,
        sampling_time=elapsed,
        num_shards=shard_count,
    )

    with open(output_path, "w", encoding="utf-8") as handle:
        json.dump(_to_jsonable(diagnostics), handle, indent=2)

    logger.info("Saved diagnostics.json: %s", output_path)
def save_all_results(
    result: CMCResult,
    output_dir: Path,
    c2_exp: np.ndarray | None = None,
    c2_fitted: np.ndarray | None = None,
    c2_fitted_std: np.ndarray | None = None,
    t1: np.ndarray | None = None,
    t2: np.ndarray | None = None,
    phi_angles: np.ndarray | None = None,
    q: float | None = None,
) -> dict[str, Path]:
    """Save all CMC result files and return a dict of ``{key: path}``.

    ``samples.npz``, ``parameters.json`` and ``diagnostics.json`` are always
    written, in that order. ``fitted_data.npz`` is written only when every
    one of the optional fitted-data arguments is supplied. *output_dir* is
    created (with parents) when missing.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    always_written = (
        ("samples", "samples.npz", save_samples_npz),
        ("parameters", "parameters.json", save_parameters_json),
        ("diagnostics", "diagnostics.json", save_diagnostics_json),
    )

    saved: dict[str, Path] = {}
    for key, filename, writer in always_written:
        target = output_dir / filename
        writer(result, target)
        saved[key] = target

    # Identity checks only: "==" against None is ambiguous for arrays.
    fitted_inputs = (c2_exp, c2_fitted, c2_fitted_std, t1, t2, phi_angles, q)
    if any(item is None for item in fitted_inputs):
        return saved

    fitted_path = output_dir / "fitted_data.npz"
    save_fitted_data_npz(
        result=result,
        c2_exp=c2_exp,  # type: ignore[arg-type]
        c2_fitted=c2_fitted,  # type: ignore[arg-type]
        c2_fitted_std=c2_fitted_std,  # type: ignore[arg-type]
        t1=t1,  # type: ignore[arg-type]
        t2=t2,  # type: ignore[arg-type]
        phi_angles=phi_angles,  # type: ignore[arg-type]
        q=q,  # type: ignore[arg-type]
        output_path=fitted_path,
    )
    saved["fitted_data"] = fitted_path
    return saved