"""Writers for MCMC/CMC analysis results."""
from __future__ import annotations
import math
import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
from heterodyne.io.json_utils import json_safe, save_json
if TYPE_CHECKING:
from heterodyne.optimization.cmc.results import CMCResult
def _tombstone_safe_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
    """Return a JSON-serializable copy of metadata for tombstone failure records.

    ``shard_diagnostics`` may contain NaN R-hat / ESS arrays from shards that
    completed NUTS but failed convergence — including them verbatim would crash
    ``json.dumps(..., allow_nan=False)``. It is replaced with a compact count
    string.

    Every other value is scrubbed recursively: non-finite float scalars
    (built-in ``float`` or NumPy floating scalars) become ``None``, including
    those nested inside dicts, lists and tuples. Values of any other type are
    passed through unchanged.

    Args:
        metadata: Result metadata mapping; it is not modified.

    Returns:
        A new dict with the same keys and sanitized values.
    """

    def _scrub(value: Any) -> Any:
        # np.floating covers scalars such as np.float32 that do not subclass
        # the built-in float and would otherwise slip past the check.
        if isinstance(value, (float, np.floating)):
            return value if math.isfinite(value) else None
        if isinstance(value, dict):
            return {key: _scrub(item) for key, item in value.items()}
        if isinstance(value, list):
            return [_scrub(item) for item in value]
        if isinstance(value, tuple):
            return tuple(_scrub(item) for item in value)
        return value

    safe: dict[str, Any] = {}
    for key, value in metadata.items():
        if key == "shard_diagnostics":
            # Per-shard r_hat/ESS arrays can hold NaN — omit the full arrays
            # and keep only how many shards there were (when that is knowable).
            count = len(value) if isinstance(value, (list, tuple, dict)) else "?"
            safe[key] = f"<{count} shards, omitted from tombstone>"
        else:
            safe[key] = _scrub(value)
    return safe
def _float_or_none(x: float) -> float | None:
    """Coerce *x* to a built-in float, mapping NaN and +/-Inf to None.

    Non-finite values have no strict-JSON representation, so callers store
    ``None`` (JSON ``null``) in their place.
    """
    as_float = float(x)
    if math.isfinite(as_float):
        return as_float
    return None
[docs]
def save_mcmc_results(
    result: CMCResult,
    output_dir: Path | str,
    prefix: str = "mcmc",
) -> dict[str, Path]:
    """Save MCMC/CMC results to files.

    On a normal run this creates:
        - {prefix}_summary.json: Parameter summaries with credible intervals
        - {prefix}_diagnostics.json: Convergence diagnostics (R-hat, ESS)
        - {prefix}_samples.npz: Full posterior samples (compressed)

    If ``result.metadata["all_shards_failed"]`` is truthy, only
    ``{prefix}_summary.json`` is written, as a tombstone record with
    ``"status": "failed"`` instead of the posterior summaries.

    In both cases the files are first written to a temporary staging
    directory created inside ``output_dir`` and then moved into place one by
    one, so a crash while serializing never leaves a half-written file at a
    final path. The move of each individual file is atomic where
    ``os.replace`` succeeds; the set of files as a whole is not.

    Args:
        result: CMC result object
        output_dir: Output directory (created if missing)
        prefix: Filename prefix

    Returns:
        Dict mapping file type ("summary", and on success also "diagnostics"
        and "samples") to saved path
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    summary_name = f"{prefix}_summary.json"
    diagnostics_name = f"{prefix}_diagnostics.json"
    samples_name = f"{prefix}_samples.npz"

    # Stage inside output_dir itself: that keeps the staging area on the same
    # filesystem as the destination (so os.replace can rename atomically) and
    # needs no write permission on the parent directory.
    with tempfile.TemporaryDirectory(
        dir=str(output_dir), prefix=".mcmc_staging_"
    ) as tmp_dir:
        staging = Path(tmp_dir)

        if result.metadata.get("all_shards_failed"):
            # Degenerate result: all shards failed → write a tombstone instead
            # of attempting to serialize NaN-filled arrays (which would crash
            # json_safe).
            tombstone: dict[str, Any] = {
                "status": "failed",
                "reason": "all_shards_failed",
                "parameter_names": result.parameter_names,
                "metadata": _tombstone_safe_metadata(result.metadata),
                "timestamp": datetime.now().isoformat(),
            }
            save_json(tombstone, staging / summary_name)
            staged: dict[str, str] = {"summary": summary_name}
        else:
            # Summary file
            summary_data = {
                "parameter_names": result.parameter_names,
                "posterior_mean": json_safe(result.posterior_mean),
                "posterior_std": json_safe(result.posterior_std),
                "credible_intervals": json_safe(result.credible_intervals),
                "map_estimate": json_safe(result.map_estimate)
                if result.map_estimate is not None
                else None,
                "timestamp": datetime.now().isoformat(),
                "num_samples": result.num_samples,
                "num_chains": result.num_chains,
            }
            save_json(summary_data, staging / summary_name)

            # Diagnostics file
            save_mcmc_diagnostics(result, staging / diagnostics_name)

            # Samples file (NPZ for efficiency)
            _save_posterior_samples(result, staging / samples_name)

            staged = {
                "summary": summary_name,
                "diagnostics": diagnostics_name,
                "samples": samples_name,
            }

        # Move the staged files to their final names. os.replace overwrites an
        # existing destination; shutil.move is kept as a fallback for the case
        # where a plain rename is refused (e.g. a cross-device move).
        saved_paths: dict[str, Path] = {}
        for kind, filename in staged.items():
            source = staging / filename
            dest = output_dir / filename
            try:
                os.replace(str(source), str(dest))
            except OSError:
                shutil.move(str(source), str(dest))
            saved_paths[kind] = dest

    return saved_paths
[docs]
def save_mcmc_diagnostics(
    result: CMCResult,
    output_path: Path | str,
    r_hat_threshold: float = 1.1,
    min_bfmi: float = 0.3,
) -> Path:
    """Save MCMC convergence diagnostics as JSON.

    The written document contains ``convergence_passed``, a
    ``parameter_diagnostics`` entry per parameter name (R-hat, whether it is
    below the threshold, bulk/tail ESS — each only when the corresponding
    array is present on ``result``), overall ``max_r_hat`` / ``min_ess_bulk``
    / BFMI fields, and a ``sampling_info`` section. Non-finite values are
    stored as ``null``; a parameter whose R-hat is non-finite gets
    ``r_hat_passed: null`` and makes ``all_r_hat_passed`` false.

    Args:
        result: CMC result object
        output_path: Output file path (parent directories are created)
        r_hat_threshold: R-hat convergence threshold (default 1.1)
        min_bfmi: Minimum BFMI threshold (default 0.3)

    Returns:
        Path to saved file
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def _as_float_array(values: Any) -> Any:
        # Coerce once so that lists and other array-likes support the NumPy
        # reductions and comparisons used below; None means "not available".
        return None if values is None else np.asarray(values, dtype=float)

    r_hat = _as_float_array(result.r_hat)
    ess_bulk = _as_float_array(result.ess_bulk)
    ess_tail = _as_float_array(result.ess_tail)

    diagnostics: dict[str, Any] = {
        "convergence_passed": result.convergence_passed,
        "parameter_diagnostics": {},
    }

    for i, name in enumerate(result.parameter_names):
        param_diag: dict[str, Any] = {}
        if r_hat is not None:
            r_hat_val = _float_or_none(r_hat[i])
            param_diag["r_hat"] = r_hat_val
            # A non-finite R-hat is neither a pass nor a fail: record null.
            param_diag["r_hat_passed"] = (
                None if r_hat_val is None else bool(r_hat_val < r_hat_threshold)
            )
        if ess_bulk is not None:
            param_diag["ess_bulk"] = _float_or_none(ess_bulk[i])
        if ess_tail is not None:
            param_diag["ess_tail"] = _float_or_none(ess_tail[i])
        diagnostics["parameter_diagnostics"][name] = param_diag

    # Overall statistics. np.max / np.min raise ValueError on zero-size
    # arrays, so an empty diagnostic array is reported as null instead.
    if r_hat is not None:
        all_finite = bool(np.all(np.isfinite(r_hat)))
        diagnostics["max_r_hat"] = (
            _float_or_none(np.max(r_hat)) if r_hat.size else None
        )
        diagnostics["r_hat_threshold"] = r_hat_threshold
        # Any non-finite R-hat fails the overall check outright.
        diagnostics["all_r_hat_passed"] = bool(
            all_finite and np.all(r_hat < r_hat_threshold)
        )

    if ess_bulk is not None:
        diagnostics["min_ess_bulk"] = (
            _float_or_none(np.min(ess_bulk)) if ess_bulk.size else None
        )

    if result.bfmi is not None:
        diagnostics["bfmi"] = json_safe(result.bfmi)
        diagnostics["bfmi_passed"] = bool(
            np.all(np.asarray(result.bfmi) > min_bfmi)
        )

    diagnostics["sampling_info"] = {
        "num_warmup": result.num_warmup,
        "num_samples": result.num_samples,
        "num_chains": result.num_chains,
        "wall_time_seconds": result.wall_time_seconds,
    }

    save_json(diagnostics, output_path)
    return output_path
def _save_posterior_samples(
    result: CMCResult,
    output_path: Path,
) -> Path:
    """Save posterior samples to a compressed NPZ file.

    The archive holds ``parameter_names``, one ``samples_{name}`` array per
    entry of ``result.samples`` (when it is not None), and the ``r_hat``,
    ``ess_bulk`` and ``ess_tail`` arrays for whichever of those are not None.

    Args:
        result: CMC result object
        output_path: Output file path. If its name does not end in ``.npz``
            the extension is appended, matching what
            ``np.savez_compressed`` does on its own.

    Returns:
        Path to the file that was actually written
    """
    output_path = Path(output_path)
    # NumPy appends ".npz" to names that lack it; apply the same rule up front
    # so the returned path always names the file that exists on disk.
    if not output_path.name.endswith(".npz"):
        output_path = output_path.with_name(output_path.name + ".npz")

    arrays: dict[str, Any] = {
        # dtype=str sizes the unicode width to the longest name; a fixed
        # width such as "U64" would silently truncate longer names.
        "parameter_names": np.array(list(result.parameter_names), dtype=str),
    }

    # Save samples for each parameter
    if result.samples is not None:
        for name, draws in result.samples.items():
            arrays[f"samples_{name}"] = np.asarray(draws)

    # Save diagnostics arrays that are available
    for key in ("r_hat", "ess_bulk", "ess_tail"):
        values = getattr(result, key)
        if values is not None:
            arrays[key] = np.asarray(values)

    np.savez_compressed(output_path, **arrays)
    return output_path