# Source code for heterodyne.viz.mcmc_plots

"""Visualization for MCMC/CMC results."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

from heterodyne.utils.logging import get_logger
from heterodyne.viz.mcmc_diagnostics import BFMI_THRESHOLD

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure

    from heterodyne.optimization.cmc.results import CMCResult

logger = get_logger(__name__)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _empty_figure(message: str) -> Figure:
    """Build a placeholder figure whose only content is ``message``.

    The text is centred in axes coordinates and both tick sets are removed,
    so the result reads as a notice rather than as an empty plot.
    """
    placeholder, canvas = plt.subplots(figsize=(8, 5))
    text_style = {
        "ha": "center",
        "va": "center",
        "transform": canvas.transAxes,
        "fontsize": 13,
    }
    canvas.text(0.5, 0.5, message, **text_style)
    for clear_ticks in (canvas.set_xticks, canvas.set_yticks):
        clear_ticks([])
    return placeholder


def _save_fig(
    fig: Figure,
    save_path: Path | str | None,
    dpi: int = 150,
) -> None:
    """Write ``fig`` to ``save_path`` and close it; do nothing without a path.

    Args:
        fig: Figure to persist.
        save_path: Destination file. ``None`` leaves the figure open and
            unsaved.
        dpi: Resolution passed to ``Figure.savefig``.
    """
    if save_path is None:
        return

    fig.savefig(save_path, bbox_inches="tight", dpi=dpi)
    logger.info("Figure saved to %s", save_path)
    # Only figures that were written to disk are closed; unsaved figures stay
    # open so the caller can still display or modify them.
    plt.close(fig)


# ---------------------------------------------------------------------------
# Existing plots (kept intact)
# ---------------------------------------------------------------------------


def plot_posterior(
    result: CMCResult,
    params: list[str] | None = None,
    save_path: Path | str | None = None,
    figsize: tuple[float, float] | None = None,
) -> Figure:
    """Plot marginal posterior histograms, one panel per parameter.

    Every panel shows a density histogram of the flattened samples. When the
    parameter also appears in ``result.parameter_names`` the panel is
    annotated with the posterior mean (solid red, shown in the legend),
    mean ± one standard deviation (dashed red) and, when available, the
    R-hat value in the title. The 2.5% / 97.5% bounds from
    ``result.credible_intervals`` are drawn as dotted green lines.

    Args:
        result: CMC result with samples
        params: Parameters to plot (None for all)
        save_path: Optional save path. When given, the figure is saved and
            closed via ``_save_fig``.
        figsize: Optional figure size

    Returns:
        Matplotlib figure
    """
    if result.samples is None:
        return _empty_figure("No samples available")

    if params is None:
        params = result.parameter_names

    n_params = len(params)
    if n_params == 0:
        return _empty_figure("No parameters to plot")

    ncols = min(3, n_params)
    nrows = (n_params + ncols - 1) // ncols
    if figsize is None:
        figsize = (4 * ncols, 3 * nrows)

    # squeeze=False always returns a 2-D grid, so a single panel needs no
    # special-casing before flattening.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    summary_names = list(result.parameter_names)

    for ax, name in zip(axes, params, strict=False):
        if name not in result.samples:
            ax.text(
                0.5,
                0.5,
                f"{name}: No samples",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
            continue

        samples = np.asarray(result.samples[name]).ravel()
        ax.hist(samples, bins=50, density=True, alpha=0.7, color="steelblue")

        rhat_str = ""
        # A sampled quantity may have no entry in the summary arrays (they are
        # indexed by ``parameter_names``); plot its histogram without the
        # summary overlay instead of failing on ``list.index``.
        if name in summary_names:
            idx = summary_names.index(name)
            mean = result.posterior_mean[idx]
            std = result.posterior_std[idx]

            ax.axvline(
                mean, color="red", linestyle="-", lw=2, label=f"Mean: {mean:.3e}"
            )
            ax.axvline(mean - std, color="red", linestyle="--", alpha=0.5)
            ax.axvline(mean + std, color="red", linestyle="--", alpha=0.5)
            # Without a legend the "Mean: ..." label above is never rendered.
            ax.legend(fontsize=7)

            if result.r_hat is not None and idx < len(result.r_hat):
                rhat_str = f" (R-hat={result.r_hat[idx]:.3f})"

        # Credible interval bounds
        if name in result.credible_intervals:
            ci = result.credible_intervals[name]
            for bound in ("2.5%", "97.5%"):
                if bound in ci:
                    ax.axvline(ci[bound], color="green", linestyle=":", alpha=0.7)

        ax.set_xlabel(name)
        ax.set_ylabel("Density")
        ax.set_title(f"{name}{rhat_str}")

    # Hide unused axes
    for spare in axes[n_params:]:
        spare.set_visible(False)

    plt.tight_layout()
    _save_fig(fig, save_path)
    return fig
def plot_trace(
    result: CMCResult,
    params: list[str] | None = None,
    save_path: Path | str | None = None,
) -> Figure:
    """Plot trace plots for MCMC chains.

    One row per parameter: the left panel shows the sample trace (one line
    per chain for 2-D ``(n_chains, n_samples)`` arrays), the right panel the
    pooled posterior histogram. The posterior mean is marked on the histogram
    when the parameter is listed in ``result.parameter_names``.

    Args:
        result: CMC result with samples
        params: Parameters to plot (None for all)
        save_path: Optional save path. When given, the figure is saved and
            closed via ``_save_fig``.

    Returns:
        Matplotlib figure
    """
    if result.samples is None:
        return _empty_figure("No samples available")

    if params is None:
        params = result.parameter_names

    n_params = len(params)
    if n_params == 0:
        return _empty_figure("No parameters to plot")

    # squeeze=False keeps the (n_params, 2) layout even for a single row.
    fig, axes = plt.subplots(
        n_params, 2, figsize=(12, 3 * n_params), squeeze=False
    )

    summary_names = list(result.parameter_names)

    for (ax_trace, ax_hist), name in zip(axes, params, strict=False):
        if name not in result.samples:
            # Label the otherwise blank row so a missing parameter is obvious.
            ax_trace.text(
                0.5,
                0.5,
                f"{name}: No samples",
                ha="center",
                va="center",
                transform=ax_trace.transAxes,
            )
            continue

        samples = np.asarray(result.samples[name])

        # Trace plot
        if samples.ndim == 2:
            # Multiple chains: one line per chain
            for chain_draws in samples:
                ax_trace.plot(chain_draws, alpha=0.7, lw=0.5)
        else:
            ax_trace.plot(samples, alpha=0.7, lw=0.5)

        ax_trace.set_ylabel(name)
        ax_trace.set_xlabel("Iteration")
        ax_trace.set_title(f"{name} - Trace")

        # Posterior histogram pooled over all chains
        ax_hist.hist(samples.ravel(), bins=50, density=True, alpha=0.7)

        # The summary arrays are indexed by ``parameter_names``; skip the mean
        # marker for sampled quantities that have no summary entry rather than
        # failing on ``list.index``.
        if name in summary_names:
            mean = result.posterior_mean[summary_names.index(name)]
            ax_hist.axvline(mean, color="red", linestyle="-", lw=2)

        ax_hist.set_xlabel(name)
        ax_hist.set_ylabel("Density")
        ax_hist.set_title(f"{name} - Posterior")

    plt.tight_layout()
    _save_fig(fig, save_path)
    return fig
def plot_corner(
    result: CMCResult,
    params: list[str] | None = None,
    save_path: Path | str | None = None,
) -> Figure:
    """Plot corner plot showing parameter correlations.

    The grid has one row and one column per parameter. Diagonal panels show
    the marginal histogram, lower-triangle panels show a scatter of the
    column parameter (x) against the row parameter (y), and upper-triangle
    panels are hidden. Parameters without samples leave their panels empty.

    Args:
        result: CMC result
        params: Parameters to include (None for all)
        save_path: Optional save path. When given, the figure is saved and
            closed via ``_save_fig``.

    Returns:
        Matplotlib figure
    """
    if result.samples is None:
        return _empty_figure("No samples available")

    if params is None:
        params = result.parameter_names

    n_params = len(params)
    if n_params == 0:
        return _empty_figure("No parameters to plot")

    fig, axes = plt.subplots(
        n_params,
        n_params,
        figsize=(2 * n_params, 2 * n_params),
        squeeze=False,
    )

    # Flatten each parameter's draws once instead of once per panel.
    flat = {
        name: np.asarray(result.samples[name]).ravel()
        for name in params
        if name in result.samples
    }
    last_row = n_params - 1

    for i, name_i in enumerate(params):
        for j, name_j in enumerate(params):
            ax = axes[i, j]

            if i < j:
                # Upper triangle: hide (it never carries edge labels)
                ax.set_visible(False)
                continue

            if i == j:
                # Diagonal: marginal histogram; density ticks are not
                # comparable with the scatter panels, so drop them.
                if name_i in flat:
                    ax.hist(flat[name_i], bins=30, density=True, alpha=0.7)
                ax.set_yticks([])
            elif name_i in flat and name_j in flat:
                # Lower triangle: pairwise scatter
                ax.scatter(flat[name_j], flat[name_i], alpha=0.1, s=1)

            # Labels only on the outer edge of the grid
            if i == last_row:
                ax.set_xlabel(name_j, fontsize=8)
            if j == 0 and i > 0:
                ax.set_ylabel(name_i, fontsize=8)

    plt.tight_layout()
    _save_fig(fig, save_path)
    return fig
# ---------------------------------------------------------------------------
# New diagnostic plots
# ---------------------------------------------------------------------------
def _hdi_bounds(draws: np.ndarray, mass: float) -> tuple[float, float]:
    """Return the narrowest interval holding ``mass`` of the 1-D ``draws``.

    ``draws`` must be non-empty and finite. The interval is found by sliding
    a window of fixed sample count over the sorted draws and keeping the
    window with the smallest width.
    """
    ordered = np.sort(draws)
    n_draws = ordered.size
    # Index distance between the two ends of each candidate window.
    span = min(int(np.floor(mass * n_draws)), n_draws - 1)
    widths = ordered[span:] - ordered[: n_draws - span]
    start = int(np.argmin(widths))
    return float(ordered[start]), float(ordered[start + span])


def plot_forest(
    samples: dict[str, np.ndarray],
    param_names: list[str] | None = None,
    credible_interval: float = 0.94,
    save_path: Path | str | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int = 150,
) -> Figure:
    """Forest plot with highest-density interval (HDI) bars.

    Each parameter is rendered as a horizontal bar spanning its HDI (the
    narrowest interval containing ``credible_interval`` of the draws), with
    a dot at the posterior mean. Multiple chains/shards are overlaid with
    slight vertical offsets when the sample array is 2-D
    ``(n_chains, n_samples)``. Non-finite draws are ignored; a chain without
    any finite draw is skipped.

    Args:
        samples: Dictionary mapping parameter names to arrays of shape
            ``(n_samples,)`` or ``(n_chains, n_samples)``.
        param_names: Ordered list of parameter names to include. When
            ``None``, all keys of ``samples`` are used (sorted).
        credible_interval: Probability mass for the HDI bars. Must be in
            ``(0, 1)``. Default is 0.94 (94% HDI).
        save_path: Optional path to save the figure.
        figsize: Figure size ``(width, height)``. Auto-computed when ``None``.
        dpi: Resolution for saved figures.

    Returns:
        Matplotlib figure.

    Raises:
        ValueError: If ``credible_interval`` is not in ``(0, 1)``.
    """
    if not 0 < credible_interval < 1:
        raise ValueError(
            f"credible_interval must be in (0, 1), got {credible_interval}"
        )

    if param_names is None:
        param_names = sorted(samples.keys())

    n_params = len(param_names)
    if n_params == 0:
        return _empty_figure("No parameters to plot")

    if figsize is None:
        figsize = (9, max(3, 0.6 * n_params + 1.5))

    fig, ax = plt.subplots(figsize=figsize)
    y_positions = np.arange(n_params, dtype=float)

    # Legend bookkeeping spans all parameters, so the legend does not depend
    # on which parameter happens to come first or last.
    labelled_chains: set[int] = set()
    max_chains = 0

    for yi, name in zip(y_positions, param_names, strict=True):
        if name not in samples:
            ax.text(0, yi, f"{name}: missing", va="center", fontsize=8, color="gray")
            continue

        arr = np.asarray(samples[name], dtype=float)
        # Rows of a 2-D array are chains; anything else is one pooled chain.
        chains = list(arr) if arr.ndim == 2 else [arr.ravel()]
        max_chains = max(max_chains, len(chains))

        colors = plt.colormaps["tab10"](np.linspace(0, 0.9, len(chains)))
        # Cap the spread so bars of many chains stay inside their own row.
        half_spread = min(0.15 * (len(chains) - 1), 0.35)
        offsets = np.linspace(-half_spread, half_spread, len(chains))

        for chain_idx, (chain_samples, color, offset) in enumerate(
            zip(chains, colors, offsets, strict=True)
        ):
            finite = chain_samples[np.isfinite(chain_samples)]
            if finite.size == 0:
                continue

            lo, hi = _hdi_bounds(finite, credible_interval)
            mean = float(np.mean(finite))
            y_bar = yi + offset

            # Label each chain index only once across the whole figure.
            label = "" if chain_idx in labelled_chains else f"Chain {chain_idx}"
            labelled_chains.add(chain_idx)

            ax.plot([lo, hi], [y_bar, y_bar], color=color, lw=3, alpha=0.7)
            ax.plot(mean, y_bar, "o", color=color, ms=6, zorder=5, label=label)

    ax.set_yticks(y_positions)
    ax.set_yticklabels(param_names, fontsize=9)
    ax.set_xlabel("Parameter value")
    # ``:g`` avoids the truncation of int(): 100 * 0.58 is 57.999..., not 58.
    ax.set_title(
        f"Forest plot — {100 * credible_interval:g}% HDI",
        fontweight="bold",
    )
    ax.grid(True, axis="x", alpha=0.3)
    ax.invert_yaxis()

    if max_chains > 1 and labelled_chains:
        ax.legend(loc="lower right", fontsize=8)

    plt.tight_layout()
    _save_fig(fig, save_path, dpi=dpi)
    return fig
def plot_energy(
    samples: dict[str, np.ndarray],
    save_path: Path | str | None = None,
    figsize: tuple[float, float] = (8, 5),
    dpi: int = 150,
) -> Figure:
    """Energy transition vs marginal energy diagnostic plot (NUTS).

    Compares the marginal energy distribution (``energy`` centred on its
    mean) with the distribution of energy transitions
    (``energy_diff = energy[1:] - energy[:-1]``). Good mixing produces
    overlapping distributions; a transition distribution much narrower than
    the marginal one indicates poor exploration of the posterior. The
    estimated BFMI, ``mean(energy_diff**2) / var(energy)``, is annotated in
    the top-right corner (green when it reaches ``BFMI_THRESHOLD``).

    Args:
        samples: Sample dictionary. Must contain an ``"energy"`` key with a
            1-D array of Hamiltonian energies recorded by NUTS.
        save_path: Optional save path.
        figsize: Figure size ``(width, height)``.
        dpi: Resolution for saved figures.

    Returns:
        Matplotlib figure. A placeholder figure is returned when the
        ``"energy"`` key is missing or holds fewer than two values.
    """
    if "energy" not in samples:
        return _empty_figure(
            "'energy' key not found in samples.\nRun NUTS with energy tracking enabled."
        )

    energy = np.asarray(samples["energy"], dtype=float).ravel()
    if energy.size < 2:
        # No transition can be formed from fewer than two energies.
        return _empty_figure(
            "At least two energy samples are required\nfor the energy diagnostic."
        )

    # Transitions are centred near zero while raw energies sit at an arbitrary
    # offset; centring the marginal energy is what makes overlap meaningful.
    energy_centered = energy - energy.mean()
    energy_diff = np.diff(energy)

    fig, ax = plt.subplots(figsize=figsize)

    bins = min(60, max(10, len(energy) // 10))
    ax.hist(
        energy_centered,
        bins=bins,
        density=True,
        alpha=0.6,
        color="steelblue",
        label="Marginal energy H(θ, r) − mean",
    )
    ax.hist(
        energy_diff,
        bins=bins,
        density=True,
        alpha=0.6,
        color="tomato",
        label="Energy transition ΔH",
    )

    ax.set_xlabel("Energy")
    ax.set_ylabel("Density")
    ax.set_title(
        "Energy diagnostic (NUTS)\nOverlap indicates good mixing", fontweight="bold"
    )
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Annotate with BFMI: E[ΔH²] / Var(H) (Stan definition, not Var(ΔH)/Var(H))
    energy_var = float(np.var(energy))
    bfmi = (
        float(np.mean(energy_diff**2)) / energy_var
        if energy_var > 0
        else float("nan")
    )
    # NaN compares False, so an undefined BFMI is flagged in red.
    bfmi_color = "green" if bfmi >= BFMI_THRESHOLD else "red"
    ax.text(
        0.97,
        0.97,
        f"BFMI ≈ {bfmi:.3f}",
        ha="right",
        va="top",
        transform=ax.transAxes,
        fontsize=10,
        color=bfmi_color,
        fontweight="bold",
    )

    plt.tight_layout()
    _save_fig(fig, save_path, dpi=dpi)
    return fig
def plot_autocorrelation(
    samples: dict[str, np.ndarray],
    param_names: list[str] | None = None,
    max_lag: int = 50,
    save_path: Path | str | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int = 150,
) -> Figure:
    """Draw the sample autocorrelation function of each parameter.

    One bar-chart panel per parameter shows the ACF for lags ``0..max_lag``
    together with the ±1.96/sqrt(n) significance band. Fast decay towards
    zero means efficient mixing; slow decay means strongly correlated draws
    and a low effective sample size. For 2-D input only the first chain is
    analysed.

    Args:
        samples: Dictionary mapping parameter names to sample arrays of
            shape ``(n_samples,)`` or ``(n_chains, n_samples)``.
        param_names: Parameters to include. Defaults to all keys (sorted).
            Names absent from ``samples`` are dropped.
        max_lag: Maximum lag to compute. Clamped to ``n_samples - 1``.
        save_path: Optional save path.
        figsize: Figure size. Auto-computed when ``None``.
        dpi: Resolution for saved figures.

    Returns:
        Matplotlib figure.
    """
    requested = sorted(samples.keys()) if param_names is None else param_names

    # Keep only the names that actually have samples
    present = [key for key in requested if key in samples]
    if len(present) == 0:
        return _empty_figure("No valid parameters for autocorrelation plot")

    n_params = len(present)
    ncols = min(3, n_params)
    nrows = -(-n_params // ncols)  # ceiling division
    if figsize is None:
        figsize = (5 * ncols, 3 * nrows)

    fig, grid = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    panels = grid.ravel()

    for panel, name in zip(panels, present, strict=False):
        draws = np.asarray(samples[name])
        if draws.ndim == 2:
            draws = draws[0]  # Use first chain
        draws = draws.ravel()

        n = len(draws)
        effective_max_lag = min(max_lag, n - 1)

        centered = draws - draws.mean()
        variance = float(np.var(centered))
        if variance == 0:
            # A constant series has no defined ACF; show a flat line instead.
            panel.axhline(0, color="gray")
            panel.set_title(f"{name}\n(zero variance)")
            continue

        lags = np.arange(effective_max_lag + 1)
        acf_values = []
        for lag in lags:
            overlap = centered[: n - lag] * centered[lag:]
            acf_values.append(float(np.mean(overlap)) / variance)
        acf = np.array(acf_values)

        panel.bar(lags, acf, color="steelblue", alpha=0.7, width=0.8)

        # 95% significance bands: ±1.96 / sqrt(n)
        band = 1.96 / np.sqrt(n)
        for level in (band, -band):
            panel.axhline(level, color="red", linestyle="--", lw=1, alpha=0.7)
        panel.axhline(0, color="black", lw=0.5)

        panel.set_xlabel("Lag")
        panel.set_ylabel("ACF")
        panel.set_title(f"{name}", fontweight="bold")
        panel.set_xlim(-0.5, effective_max_lag + 0.5)
        panel.set_ylim(-1.05, 1.05)
        panel.grid(True, alpha=0.3)

    # Hide grid cells that have no parameter assigned
    for spare in panels[n_params:]:
        spare.set_visible(False)

    fig.suptitle("Sample Autocorrelation Functions", fontsize=13, fontweight="bold")
    plt.tight_layout()
    _save_fig(fig, save_path, dpi=dpi)
    return fig
def plot_rank_histogram(
    samples: dict[str, np.ndarray],
    param_names: list[str] | None = None,
    save_path: Path | str | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int = 150,
) -> Figure:
    """Rank histogram (rank plot) for assessing between-chain mixing.

    For each parameter, the combined samples from all chains are ranked.
    Each chain's samples are assigned their global ranks, and the rank
    distribution for each chain is plotted as a density histogram over a
    common set of bins spanning the full rank range. Uniform rank histograms
    indicate well-mixed chains; U-shaped or spike-tailed histograms indicate
    convergence problems. The dashed line marks the density of a perfectly
    uniform rank distribution, ``1 / (n_chains * n_samples)``.

    Requires sample arrays of shape ``(n_chains, n_samples)``. Arrays with
    fewer than two dimensions cannot be ranked across chains: a warning is
    logged and the panel shows a text notice instead of a histogram. Extra
    trailing dimensions are flattened into the draw axis.

    Args:
        samples: Dictionary mapping parameter names to arrays of shape
            ``(n_chains, n_samples)`` or ``(n_samples,)``.
        param_names: Parameters to include. Defaults to all keys (sorted).
            Names absent from ``samples`` are dropped.
        save_path: Optional save path.
        figsize: Figure size. Auto-computed when ``None``.
        dpi: Resolution for saved figures.

    Returns:
        Matplotlib figure.
    """
    if param_names is None:
        param_names = sorted(samples.keys())

    valid = [p for p in param_names if p in samples]
    if not valid:
        return _empty_figure("No valid parameters for rank histogram")

    n_params = len(valid)
    ncols = min(3, n_params)
    nrows = (n_params + ncols - 1) // ncols
    if figsize is None:
        figsize = (5 * ncols, 3 * nrows)

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes_flat = axes.ravel()

    for ax, name in zip(axes_flat, valid, strict=False):
        arr = np.asarray(samples[name])

        if arr.ndim < 2:
            logger.warning(
                "plot_rank_histogram: '%s' is 1-D; rank plot requires "
                "multiple chains (n_chains, n_samples)",
                name,
            )
            ax.text(
                0.5,
                0.5,
                f"{name}\nSingle chain — no rank plot",
                ha="center",
                va="center",
                transform=ax.transAxes,
                fontsize=9,
            )
            continue

        if arr.ndim > 2:
            # Treat the leading axis as chains and pool everything else.
            arr = arr.reshape(arr.shape[0], -1)

        n_chains, n_samples = arr.shape
        if n_samples < 2:
            ax.text(
                0.5,
                0.5,
                f"{name}\nNeed ≥ 2 draws per chain",
                ha="center",
                va="center",
                transform=ax.transAxes,
                fontsize=9,
            )
            continue

        n_total = n_chains * n_samples

        # Ordinal ranks 0..n_total-1 over the pooled draws (double argsort);
        # the stable sort makes tie-breaking deterministic.
        order = np.argsort(np.argsort(arr.ravel(), kind="stable"))
        ranks_per_chain = order.reshape(n_chains, n_samples)

        n_bins = max(1, min(n_samples // 2, 20))
        # Shared edges over the full rank range: binning each chain over its
        # own min/max would make a chain stuck in one region look uniform.
        bin_edges = np.linspace(0, n_total, n_bins + 1)
        colors = plt.colormaps["tab10"](np.linspace(0, 0.9, n_chains))

        for chain_idx in range(n_chains):
            ax.hist(
                ranks_per_chain[chain_idx],
                bins=bin_edges,
                alpha=max(0.3, 0.8 / n_chains),
                color=colors[chain_idx],
                density=True,
                label=f"Chain {chain_idx}",
            )

        # With density=True each chain's histogram integrates to 1 over the
        # rank range [0, n_total], so a uniform chain has height 1 / n_total
        # in every bin, independent of the number of bins.
        expected = 1.0 / n_total
        ax.axhline(
            expected,
            color="black",
            linestyle="--",
            lw=1.5,
            label="Uniform",
        )
        ax.set_xlabel("Rank")
        ax.set_ylabel("Density")
        ax.set_title(f"{name}", fontweight="bold")
        if n_chains <= 5:
            ax.legend(fontsize=7)
        ax.grid(True, alpha=0.3)

    for spare in axes_flat[n_params:]:
        spare.set_visible(False)

    fig.suptitle(
        "Rank Histograms (chain mixing diagnostic)", fontsize=13, fontweight="bold"
    )
    plt.tight_layout()
    _save_fig(fig, save_path, dpi=dpi)
    return fig
def plot_posterior_predictive(
    c2_observed: np.ndarray,
    c2_predicted: np.ndarray,
    times: np.ndarray,
    phi_angle: float | None = None,
    n_samples_overlay: int = 50,
    save_path: Path | str | None = None,
    figsize: tuple[float, float] = (14, 5),
    dpi: int = 150,
) -> Figure:
    """Posterior predictive check — model vs observed data overlay.

    Displays three panels:

    1. Observed ``c2`` two-time matrix.
    2. Posterior predictive mean (average over ``c2_predicted`` samples).
    3. Residual ``observed - mean_predicted`` with symmetric colour scale.

    A fourth panel with diagonal slices ``c2(t, t)`` of randomly chosen
    draws (fixed seed), the observation and the predictive mean is added
    when ``n_samples_overlay > 0`` and more than one predictive draw is
    supplied.

    Args:
        c2_observed: Observed correlation matrix, shape ``(n_t, n_t)``.
        c2_predicted: Posterior predictive draws, shape
            ``(n_posterior, n_t, n_t)`` or ``(n_t, n_t)`` for a single
            prediction.
        times: 1-D time axis of length ``n_t``.
        phi_angle: Optional azimuthal angle in degrees for the title.
        n_samples_overlay: Maximum number of random diagonal slices to
            overlay on the fourth panel (0 disables the panel).
        save_path: Optional save path.
        figsize: Figure size for the three-panel layout; the width is scaled
            up proportionally when the fourth panel is shown.
        dpi: Resolution for saved figures.

    Returns:
        Matplotlib figure.

    Raises:
        ValueError: If ``c2_observed`` is not 2-D, if its shape differs from
            the per-draw shape of ``c2_predicted``, or if ``times`` is not a
            1-D array of length ``n_t``.
    """
    c2_obs = np.asarray(c2_observed)
    c2_pred = np.asarray(c2_predicted)
    time_axis = np.asarray(times)

    if c2_pred.ndim == 2:
        c2_pred = c2_pred[np.newaxis]

    # Validate everything before a figure exists so a bad call cannot leave
    # an orphaned open figure behind.
    if c2_obs.ndim != 2:
        raise ValueError(f"c2_observed must be 2-D, got shape {c2_obs.shape}")
    if c2_obs.shape != c2_pred.shape[1:]:
        raise ValueError(
            f"c2_observed shape {c2_obs.shape} does not match "
            f"c2_predicted shape {c2_pred.shape[1:]}"
        )
    if time_axis.ndim != 1 or time_axis.shape[0] != c2_obs.shape[0]:
        raise ValueError(
            f"times must be 1-D with length {c2_obs.shape[0]}, "
            f"got shape {time_axis.shape}"
        )

    c2_mean = np.mean(c2_pred, axis=0)
    residual = c2_obs - c2_mean

    add_diagonal = n_samples_overlay > 0 and c2_pred.shape[0] > 1
    ncols = 4 if add_diagonal else 3
    fig, axes = plt.subplots(1, ncols, figsize=(figsize[0] * ncols / 3, figsize[1]))

    # extent = (left, right, bottom, top); t1 increases downwards.
    t_extent = [time_axis[0], time_axis[-1], time_axis[-1], time_axis[0]]
    cmap_data = "viridis"

    # Robust colour limits from the observation, shared by the first two
    # panels so they are directly comparable.
    vmin = float(np.nanpercentile(c2_obs, 1))
    vmax = float(np.nanpercentile(c2_obs, 99))

    data_panels = (
        (axes[0], c2_obs, "Observed c₂"),
        (axes[1], c2_mean, "Posterior predictive mean"),
    )
    for ax, image, panel_title in data_panels:
        im = ax.imshow(
            image,
            extent=t_extent,
            aspect="auto",
            cmap=cmap_data,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(panel_title)
        ax.set_xlabel("t₂")
        ax.set_ylabel("t₁")
        plt.colorbar(im, ax=ax, shrink=0.8)

    vmax_res = float(np.nanpercentile(np.abs(residual), 99))
    im_res = axes[2].imshow(
        residual,
        extent=t_extent,
        aspect="auto",
        cmap="RdBu_r",
        vmin=-vmax_res,
        vmax=vmax_res,
    )
    axes[2].set_title("Residual (obs - mean)")
    axes[2].set_xlabel("t₂")
    axes[2].set_ylabel("t₁")
    plt.colorbar(im_res, ax=axes[2], shrink=0.8)

    if add_diagonal:
        ax_diag: Axes = axes[3]
        # Fixed seed: the same draws are shown on every call.
        rng = np.random.default_rng(0)
        n_draw = min(n_samples_overlay, c2_pred.shape[0])
        chosen = rng.choice(c2_pred.shape[0], size=n_draw, replace=False)
        for draw_idx in chosen:
            ax_diag.plot(
                time_axis,
                np.diag(c2_pred[draw_idx]),
                color="steelblue",
                alpha=0.15,
                lw=0.8,
            )
        ax_diag.plot(
            time_axis, np.diag(c2_obs), color="black", lw=2, label="Observed"
        )
        ax_diag.plot(
            time_axis,
            np.diag(c2_mean),
            color="red",
            lw=2,
            linestyle="--",
            label="Pred. mean",
        )
        ax_diag.set_xlabel("Time")
        ax_diag.set_ylabel("c₂(t, t)")
        ax_diag.set_title(f"Diagonal slices ({n_draw} draws)")
        ax_diag.legend(fontsize=8)
        ax_diag.grid(True, alpha=0.3)

    title = "Posterior Predictive Check"
    if phi_angle is not None:
        title += f" [φ = {phi_angle:.1f}°]"
    fig.suptitle(title, fontsize=13, fontweight="bold")

    plt.tight_layout()
    _save_fig(fig, save_path, dpi=dpi)
    return fig
def plot_shard_comparison(
    shard_results: list[dict[str, np.ndarray]],
    param_names: list[str] | None = None,
    save_path: Path | str | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int = 150,
) -> Figure:
    """Overlay per-shard marginal posteriors for CMC diagnostics.

    For every requested parameter, one semi-transparent density histogram
    per shard is drawn in the same panel. Overlapping histograms mean the
    shards agree; clearly separated ones point to multi-modality or
    heterogeneous data.

    Args:
        shard_results: List of sample dictionaries, one per shard. Each
            dictionary maps parameter names to sample arrays (flattened
            before plotting).
        param_names: Parameters to compare. Defaults to the sorted keys of
            the first shard.
        save_path: Optional save path.
        figsize: Figure size. Auto-computed when ``None``.
        dpi: Resolution for saved figures.

    Returns:
        Matplotlib figure.

    Raises:
        ValueError: If ``shard_results`` is empty.
    """
    if not shard_results:
        raise ValueError("shard_results must be non-empty")

    names = sorted(shard_results[0].keys()) if param_names is None else param_names

    n_params = len(names)
    if n_params == 0:
        return _empty_figure("No parameters to compare across shards")

    ncols = min(3, n_params)
    nrows = -(-n_params // ncols)  # ceiling division
    if figsize is None:
        figsize = (5 * ncols, 3.5 * nrows)

    fig, grid = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    panels = grid.ravel()

    n_shards = len(shard_results)
    # Legends are suppressed for large shard counts to keep panels readable.
    show_labels = n_shards <= 10
    shard_colors = plt.colormaps["tab10"](np.linspace(0, 0.9, n_shards))

    for panel_idx, name in enumerate(names):
        panel: Axes = panels[panel_idx]

        for shard_idx, shard in enumerate(shard_results):
            if name in shard:
                draws = np.asarray(shard[name]).ravel()
                panel.hist(
                    draws,
                    bins=30,
                    density=True,
                    alpha=0.4,
                    color=shard_colors[shard_idx],
                    label=f"Shard {shard_idx}" if show_labels else "",
                )

        panel.set_xlabel(name)
        panel.set_ylabel("Density")
        panel.set_title(f"{name}", fontweight="bold")
        panel.grid(True, alpha=0.3)
        # One legend on the first panel is enough: colours are shared.
        if panel_idx == 0 and show_labels:
            panel.legend(fontsize=7)

    for spare in panels[n_params:]:
        spare.set_visible(False)

    fig.suptitle(
        f"Cross-shard posterior comparison ({n_shards} shards)",
        fontsize=13,
        fontweight="bold",
    )
    plt.tight_layout()
    _save_fig(fig, save_path, dpi=dpi)
    return fig
[docs] def plot_divergence_scatter( samples: dict[str, np.ndarray], divergent_mask: np.ndarray, param_pairs: list[tuple[str, str]] | None = None, save_path: Path | str | None = None, figsize: tuple[float, float] | None = None, dpi: int = 150, ) -> Figure: """Scatter plot highlighting divergent transitions. Plots pairs of parameters as scatter plots, colouring divergent transitions in red and non-divergent transitions in grey. Divergent transitions that cluster in parameter space indicate problematic posterior geometry (e.g., funnel-shaped posteriors). Args: samples: Dictionary mapping parameter names to 1-D sample arrays of the same length. divergent_mask: Boolean array of length ``n_samples`` where ``True`` marks a divergent transition. param_pairs: List of ``(param_x, param_y)`` tuples to plot. Defaults to all adjacent pairs of sorted parameter names. save_path: Optional save path. figsize: Figure size. Auto-computed when ``None``. dpi: Resolution for saved figures. Returns: Matplotlib figure. Raises: ValueError: If ``divergent_mask`` length does not match sample length, or if a requested parameter is missing from ``samples``. 
""" div_mask = np.asarray(divergent_mask, dtype=bool).ravel() # Build default param pairs from sorted keys all_params = sorted(samples.keys()) if param_pairs is None: param_pairs = [ (all_params[i], all_params[i + 1]) for i in range(min(len(all_params) - 1, 5)) ] if not param_pairs: return _empty_figure("No parameter pairs for divergence scatter") ncols = min(3, len(param_pairs)) nrows = (len(param_pairs) + ncols - 1) // ncols if figsize is None: figsize = (4 * ncols, 4 * nrows) fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False) axes_flat = axes.ravel() for ax_idx, (px, py) in enumerate(param_pairs): ax: Axes = axes_flat[ax_idx] if px not in samples: ax.text( 0.5, 0.5, f"Missing: {px}", ha="center", va="center", transform=ax.transAxes, ) continue if py not in samples: ax.text( 0.5, 0.5, f"Missing: {py}", ha="center", va="center", transform=ax.transAxes, ) continue x = np.asarray(samples[px]).ravel() y = np.asarray(samples[py]).ravel() if len(div_mask) != len(x): raise ValueError( f"divergent_mask length {len(div_mask)} does not match " f"sample length {len(x)}" ) non_div = ~div_mask ax.scatter( x[non_div], y[non_div], s=2, alpha=0.3, color="lightgrey", label="Non-divergent", ) n_div = int(div_mask.sum()) if n_div > 0: ax.scatter( x[div_mask], y[div_mask], s=20, alpha=0.8, color="red", zorder=5, label=f"Divergent ({n_div})", ) ax.set_xlabel(px, fontsize=9) ax.set_ylabel(py, fontsize=9) ax.set_title(f"{px} vs {py}", fontweight="bold") ax.legend(fontsize=7) ax.grid(True, alpha=0.2) for idx in range(len(param_pairs), len(axes_flat)): axes_flat[idx].set_visible(False) n_div_total = int(div_mask.sum()) n_total = len(div_mask) fig.suptitle( f"Divergent transitions [{n_div_total}/{n_total} = {100 * n_div_total / max(n_total, 1):.1f}%]", fontsize=13, fontweight="bold", ) plt.tight_layout() _save_fig(fig, save_path, dpi=dpi) return fig
def plot_rhat_summary(
    rhat_dict: dict[str, float],
    threshold: float = 1.01,
    save_path: Path | str | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int = 150,
) -> Figure:
    """Bar chart of per-parameter R-hat values against a threshold.

    Each parameter gets one bar, green when its R-hat is strictly below
    ``threshold`` and red otherwise. The threshold is drawn as a dashed
    horizontal line, every bar is annotated with its value, and the number
    of converged parameters is reported in the top-left corner.

    Args:
        rhat_dict: Dictionary mapping parameter names to their R-hat
            scalar values.
        threshold: Convergence threshold. The standard criterion is
            R-hat < 1.01 (strict) or < 1.1 (relaxed). Default is 1.01.
        save_path: Optional save path.
        figsize: Figure size. Auto-computed when ``None``.
        dpi: Resolution for saved figures.

    Returns:
        Matplotlib figure.

    Raises:
        ValueError: If ``rhat_dict`` is empty.
    """
    if not rhat_dict:
        raise ValueError("rhat_dict must be non-empty")

    names = list(rhat_dict.keys())
    values = [float(rhat_dict[key]) for key in names]
    n = len(names)

    # NaN compares False against the threshold, so it counts as not converged.
    converged = [value < threshold for value in values]

    if figsize is None:
        figsize = (max(6, 0.6 * n + 2), 5)

    fig, ax = plt.subplots(figsize=figsize)

    bar_colors = ["green" if ok else "tomato" for ok in converged]
    x = np.arange(n)
    ax.bar(x, values, color=bar_colors, alpha=0.8, edgecolor="white")
    ax.axhline(
        threshold,
        color="red",
        linestyle="--",
        lw=2,
        label=f"Threshold = {threshold}",
    )

    # Annotate bars with numeric values, just above each bar top
    for position, value in zip(x, values, strict=True):
        ax.text(
            position,
            value + 0.001,
            f"{value:.3f}",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=45, ha="right", fontsize=9)
    ax.set_ylabel("R-hat")
    ax.set_title("R-hat convergence summary", fontweight="bold")
    ax.legend()
    ax.grid(True, axis="y", alpha=0.3)

    n_converged = sum(1 for ok in converged if ok)
    summary_color = "green" if n_converged == n else "red"
    ax.text(
        0.02,
        0.97,
        f"Converged: {n_converged}/{n}",
        ha="left",
        va="top",
        transform=ax.transAxes,
        fontsize=10,
        color=summary_color,
    )

    plt.tight_layout()
    _save_fig(fig, save_path, dpi=dpi)
    return fig