# Source code for heterodyne.viz.mcmc_arviz

"""ArviZ-based MCMC visualization for heterodyne analysis.

Provides publication-quality trace, posterior, and pair plots using ArviZ,
with graceful fallback to the basic mcmc_plots module when ArviZ is not
installed.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from matplotlib.figure import Figure

    from heterodyne.optimization.cmc.results import CMCResult

logger = get_logger(__name__)


def _has_arviz() -> bool:
    """Check if ArviZ is available."""
    try:
        import arviz  # noqa: F401

        return True
    except ImportError:
        return False


def to_inference_data(cmc_result: CMCResult) -> Any:
    """Build an ArviZ ``InferenceData`` object from a finished CMC run.

    This is a thin wrapper: the actual conversion is performed by
    ``cmc_result_to_arviz`` in ``heterodyne.optimization.cmc.results``, which
    is imported at call time rather than at module import.

    Args:
        cmc_result: Completed CMC analysis result with posterior samples.

    Returns:
        ``arviz.InferenceData`` with posterior group.

    Raises:
        ImportError: If ArviZ is not installed.
        ValueError: If samples are empty.
    """
    from heterodyne.optimization.cmc import results as cmc_results

    return cmc_results.cmc_result_to_arviz(cmc_result)
def plot_arviz_trace(
    result: CMCResult,
    var_names: list[str] | None = None,
    save_path: Path | str | None = None,
) -> Figure | None:
    """Plot MCMC traces using ArviZ, with fallback to the basic trace plot.

    ``heterodyne.viz.mcmc_plots.plot_trace`` is used instead when ArviZ is not
    installed, or when the installed ArviZ does not provide ``plot_trace``.
    ``var_names`` is not forwarded to that fallback.

    Args:
        result: CMC result with posterior samples.
        var_names: Parameter names to plot. None plots all.
        save_path: Path to save the figure. When given, the figure is saved
            and then closed. When None, the figure is returned without being
            saved or closed, so the caller can display it.

    Returns:
        Matplotlib Figure from ArviZ, or the return value of the basic
        fallback plot (which may be None) when ArviZ is unavailable.
    """
    if not _has_arviz():
        logger.warning("ArviZ not installed; falling back to basic trace plot")
        from heterodyne.viz.mcmc_plots import plot_trace

        return plot_trace(result, save_path=save_path)

    import arviz as az
    import matplotlib.pyplot as plt

    # plot_arviz_posterior already guards against ArviZ releases that removed
    # a plotting function (az.plot_posterior in 1.1.0); apply the same guard
    # here, before paying for the InferenceData conversion.
    if not hasattr(az, "plot_trace"):
        logger.warning(
            "Installed ArviZ has no plot_trace; falling back to basic trace plot"
        )
        from heterodyne.viz.mcmc_plots import plot_trace

        return plot_trace(result, save_path=save_path)

    idata = to_inference_data(result)
    # Pin the backend: the code below relies on matplotlib Axes/Figure methods
    # (get_figure, savefig), which a non-matplotlib rcParams backend would break.
    axes = az.plot_trace(
        idata,
        var_names=var_names,
        compact=True,
        backend="matplotlib",
    )
    # Usually an ndarray of Axes; also accept a single Axes, as the siblings do.
    fig = axes.ravel()[0].get_figure() if hasattr(axes, "ravel") else axes.get_figure()

    if save_path is not None:
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        logger.info("Saved ArviZ trace plot to %s", save_path)
        plt.close(fig)

    return fig  # type: ignore[no-any-return]
def plot_arviz_posterior(
    result: CMCResult,
    var_names: list[str] | None = None,
    hdi_prob: float = 0.95,
    save_path: Path | str | None = None,
) -> Figure | None:
    """Plot posterior distributions using ArviZ.

    ``heterodyne.viz.mcmc_plots.plot_posterior`` is used instead when ArviZ is
    not installed, or when the installed ArviZ no longer provides
    ``plot_posterior`` (removed in ArviZ 1.1.0). In both fallback paths
    ``var_names`` is forwarded as ``params``, and ``hdi_prob`` is not used.

    Args:
        result: CMC result with posterior samples.
        var_names: Parameter names to plot. None plots all.
        hdi_prob: Highest density interval probability. Default 0.95.
        save_path: Path to save the figure. When given, the figure is saved
            and then closed; when None it is returned unsaved and open.

    Returns:
        Matplotlib Figure, or the return value of the basic fallback plot.
    """
    if not _has_arviz():
        logger.warning("ArviZ not installed; falling back to basic posterior plot")
        from heterodyne.viz.mcmc_plots import plot_posterior

        # Forward var_names here too, matching the ArviZ-1.1 fallback below.
        return plot_posterior(result, params=var_names, save_path=save_path)

    import arviz as az
    import matplotlib.pyplot as plt

    # az.plot_posterior was removed in ArviZ 1.1.0 — fall back to built-in plots.
    # Checked before conversion so the fallback does not build an
    # InferenceData it never uses.
    if not hasattr(az, "plot_posterior"):
        logger.warning(
            "Installed ArviZ has no plot_posterior; falling back to basic posterior plot"
        )
        from heterodyne.viz.mcmc_plots import plot_posterior

        return plot_posterior(result, params=var_names, save_path=save_path)

    idata = to_inference_data(result)
    # Pin the backend: the code below relies on matplotlib Axes/Figure methods.
    axes = az.plot_posterior(
        idata,
        var_names=var_names,
        hdi_prob=hdi_prob,
        backend="matplotlib",
    )
    # An ndarray of Axes for several variables, a single Axes otherwise.
    fig = axes.ravel()[0].get_figure() if hasattr(axes, "ravel") else axes.get_figure()

    if save_path is not None:
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        logger.info("Saved ArviZ posterior plot to %s", save_path)
        plt.close(fig)

    return fig  # type: ignore[no-any-return]
def plot_arviz_pair(
    result: CMCResult,
    var_names: list[str] | None = None,
    save_path: Path | str | None = None,
) -> Figure | None:
    """Plot pairwise posterior relationships (KDE with marginals) using ArviZ.

    There is no basic fallback for pair plots. When ArviZ is not installed, or
    the installed ArviZ does not provide ``plot_pair``, a warning is logged and
    None is returned.

    Args:
        result: CMC result with posterior samples.
        var_names: Parameter names to include. None uses all.
        save_path: Path to save the figure. When given, the figure is saved
            and then closed; when None it is returned unsaved and open.

    Returns:
        Matplotlib Figure, or None when ArviZ pair plotting is unavailable.
    """
    if not _has_arviz():
        logger.warning("ArviZ not installed; skipping pair plot")
        return None

    import arviz as az
    import matplotlib.pyplot as plt

    # Same defensive check plot_arviz_posterior applies for removed ArviZ
    # plotting functions; done before the InferenceData conversion.
    if not hasattr(az, "plot_pair"):
        logger.warning("Installed ArviZ has no plot_pair; skipping pair plot")
        return None

    idata = to_inference_data(result)
    # Pin the backend: the code below relies on matplotlib Axes/Figure methods.
    axes = az.plot_pair(
        idata,
        var_names=var_names,
        kind="kde",
        marginals=True,
        backend="matplotlib",
    )
    fig = axes.ravel()[0].get_figure() if hasattr(axes, "ravel") else axes.get_figure()

    if save_path is not None:
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        logger.info("Saved ArviZ pair plot to %s", save_path)
        plt.close(fig)

    return fig  # type: ignore[no-any-return]
def _create_empty_figure(title: str = "No data available") -> Figure:
    """Return a bare 6x4-inch placeholder figure that shows ``title``.

    Intended as a safe stand-in when a plotting routine has nothing meaningful
    to draw (for example empty posterior samples or a failed ArviZ
    conversion).

    Args:
        title: Message rendered in grey at the centre of the figure.

    Returns:
        Matplotlib Figure with a single hidden-axis Axes holding the message.
    """
    import matplotlib.pyplot as plt

    figure, axis = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))
    text_style = {"ha": "center", "va": "center", "fontsize": 14, "color": "gray"}
    # With the default 0–1 axis limits, (0.5, 0.5) is the centre of the axes.
    axis.text(0.5, 0.5, title, **text_style)
    axis.set_axis_off()
    return figure