"""MCMC convergence diagnostic plots for heterodyne analysis.
Provides visualization of convergence metrics including ESS evolution,
adaptation summaries, and divergence analysis for NUTS sampling.
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
from heterodyne.utils.logging import get_logger
if TYPE_CHECKING:
from matplotlib.figure import Figure
from heterodyne.optimization.cmc.results import CMCResult
logger = get_logger(__name__)
# Diagnostic threshold constants
ESS_THRESHOLD = 400
RHAT_THRESHOLD = 1.1
BFMI_THRESHOLD = 0.3
[docs]
def plot_ess_evolution(
result: CMCResult,
save_path: Path | str | None = None,
) -> Figure:
"""Plot effective sample size (ESS) across parameters.
Shows bulk and tail ESS as a grouped bar chart, with a horizontal
reference line at the minimum recommended ESS (400).
Args:
result: CMC result with ESS diagnostics.
save_path: Path to save the figure.
Returns:
Matplotlib Figure.
"""
fig, ax = plt.subplots(figsize=(max(8, len(result.parameter_names) * 0.8), 5))
x = np.arange(len(result.parameter_names))
width = 0.35
if result.ess_bulk is not None:
ax.bar(
x - width / 2,
result.ess_bulk,
width,
label="ESS bulk",
color="C0",
alpha=0.8,
)
if result.ess_tail is not None:
ax.bar(
x + width / 2,
result.ess_tail,
width,
label="ESS tail",
color="C1",
alpha=0.8,
)
ax.axhline(
y=ESS_THRESHOLD,
color="red",
linestyle="--",
alpha=0.5,
label=f"Min recommended ({ESS_THRESHOLD})",
)
ax.set_xticks(x)
ax.set_xticklabels(result.parameter_names, rotation=45, ha="right", fontsize=8)
ax.set_ylabel("Effective Sample Size")
ax.set_title("ESS by Parameter")
ax.legend(fontsize=8)
fig.tight_layout()
if save_path is not None:
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
logger.info("Saved ESS evolution plot to %s", save_path)
plt.close(fig)
return fig
[docs]
def plot_adaptation_summary(
result: CMCResult,
save_path: Path | str | None = None,
) -> Figure:
"""Plot R-hat convergence diagnostic across parameters.
Shows R-hat values as a bar chart with a horizontal reference line
at the convergence threshold (1.1).
Args:
result: CMC result with R-hat diagnostics.
save_path: Path to save the figure.
Returns:
Matplotlib Figure.
"""
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Panel 1: R-hat
ax_rhat = axes[0]
if result.r_hat is not None:
x = np.arange(len(result.parameter_names))
colors = ["red" if rh > RHAT_THRESHOLD else "C0" for rh in result.r_hat]
ax_rhat.bar(x, result.r_hat, color=colors, alpha=0.8)
ax_rhat.axhline(
y=RHAT_THRESHOLD,
color="red",
linestyle="--",
alpha=0.5,
label=f"Threshold ({RHAT_THRESHOLD})",
)
ax_rhat.set_xticks(x)
ax_rhat.set_xticklabels(
result.parameter_names, rotation=45, ha="right", fontsize=8
)
ax_rhat.set_ylabel("R-hat")
ax_rhat.set_title("R-hat by Parameter")
ax_rhat.legend(fontsize=8)
else:
ax_rhat.text(
0.5,
0.5,
"R-hat not available",
ha="center",
va="center",
transform=ax_rhat.transAxes,
)
# Panel 2: BFMI
ax_bfmi = axes[1]
if result.bfmi is not None:
chain_idx = np.arange(len(result.bfmi))
colors = ["red" if b < BFMI_THRESHOLD else "C0" for b in result.bfmi]
ax_bfmi.bar(chain_idx, result.bfmi, color=colors, alpha=0.8)
ax_bfmi.axhline(
y=BFMI_THRESHOLD,
color="red",
linestyle="--",
alpha=0.5,
label=f"Min threshold ({BFMI_THRESHOLD})",
)
ax_bfmi.set_xlabel("Chain")
ax_bfmi.set_ylabel("BFMI")
ax_bfmi.set_title("Bayesian Fraction of Missing Information")
ax_bfmi.legend(fontsize=8)
else:
ax_bfmi.text(
0.5,
0.5,
"BFMI not available",
ha="center",
va="center",
transform=ax_bfmi.transAxes,
)
fig.tight_layout()
if save_path is not None:
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
logger.info("Saved adaptation summary to %s", save_path)
plt.close(fig)
return fig
[docs]
def plot_divergence_scatter(
result: CMCResult,
save_path: Path | str | None = None,
) -> Figure:
"""Plot divergence analysis scatter plot.
If divergence information is available in result metadata, shows a
scatter plot of parameter values for divergent vs non-divergent
transitions. Falls back to a posterior density summary if no
divergence info is available.
Args:
result: CMC result.
save_path: Path to save the figure.
Returns:
Matplotlib Figure.
"""
fig, ax = plt.subplots(figsize=(8, 6))
divergent = result.metadata.get("divergent_transitions")
if (
divergent is not None
and result.samples is not None
and len(result.parameter_names) >= 2
):
divergent = np.asarray(divergent, dtype=bool).ravel()
p1_name = result.parameter_names[0]
p2_name = result.parameter_names[1]
if p1_name not in result.samples or p2_name not in result.samples:
ax.text(
0.5,
0.5,
"Divergence samples unavailable",
ha="center",
va="center",
transform=ax.transAxes,
)
fig.tight_layout()
if save_path is not None:
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
plt.close(fig)
return fig
s1 = np.asarray(result.samples[p1_name]).ravel()
s2 = np.asarray(result.samples[p2_name]).ravel()
# Trim to match if needed
n = min(len(s1), len(s2), len(divergent))
s1, s2, divergent = s1[:n], s2[:n], divergent[:n]
non_div = ~divergent
ax.scatter(
s1[non_div], s2[non_div], alpha=0.1, s=1, color="C0", label="Non-divergent"
)
if np.any(divergent):
ax.scatter(
s1[divergent],
s2[divergent],
alpha=0.8,
s=10,
color="red",
marker="x",
label=f"Divergent ({np.sum(divergent)})",
)
ax.set_xlabel(p1_name)
ax.set_ylabel(p2_name)
ax.legend(fontsize=8)
ax.set_title("Divergence Analysis")
else:
# Fallback: show posterior means with R-hat coloring
if result.r_hat is not None:
x = np.arange(len(result.parameter_names))
scatter = ax.scatter(
x,
result.posterior_mean,
c=result.r_hat,
cmap="RdYlGn_r",
s=80,
edgecolors="black",
linewidths=0.5,
vmin=0.99,
vmax=1.2,
)
ax.errorbar(
x,
result.posterior_mean,
yerr=result.posterior_std,
fmt="none",
ecolor="gray",
alpha=0.5,
)
ax.set_xticks(x)
ax.set_xticklabels(
result.parameter_names, rotation=45, ha="right", fontsize=8
)
plt.colorbar(scatter, ax=ax, label="R-hat")
ax.set_ylabel("Posterior Mean")
ax.set_title("Posterior Summary (colored by R-hat)")
else:
ax.text(
0.5,
0.5,
"No divergence or diagnostic data available",
ha="center",
va="center",
transform=ax.transAxes,
)
fig.tight_layout()
if save_path is not None:
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
logger.info("Saved divergence scatter to %s", save_path)
plt.close(fig)
return fig
[docs]
def plot_kl_divergence_matrix(
result: CMCResult,
save_path: Path | str | None = None,
) -> Figure:
"""Plot pairwise KL divergence heatmap between parameter posteriors.
Computes histogram-based KL divergence for each pair of parameters
using 50-bin histograms with epsilon smoothing.
Args:
result: CMC result with posterior samples.
save_path: Path to save the figure.
Returns:
Matplotlib Figure.
"""
names = result.parameter_names
n_params = len(names)
kl_matrix = np.zeros((n_params, n_params))
eps = 1e-10
n_bins = 50
if result.samples is not None:
for i in range(n_params):
samples_i = np.asarray(result.samples[names[i]]).ravel()
for j in range(n_params):
if i == j:
continue
samples_j = np.asarray(result.samples[names[j]]).ravel()
# Shared range for both histograms
lo = min(float(np.min(samples_i)), float(np.min(samples_j)))
hi = max(float(np.max(samples_i)), float(np.max(samples_j)))
if lo == hi:
continue
bins = np.linspace(lo, hi, n_bins + 1)
# Use density=False, then normalize once to get a true PMF.
# density=True would require a second division by bin_width, not
# by sum(p), causing double-normalization and deflated KL values.
p, _ = np.histogram(samples_i, bins=bins, density=False)
q, _ = np.histogram(samples_j, bins=bins, density=False)
p = p / (np.sum(p) + eps)
q = q / (np.sum(q) + eps)
# KL(p || q)
p_safe = p + eps
q_safe = q + eps
kl_matrix[i, j] = float(np.sum(p_safe * np.log(p_safe / q_safe)))
fig, ax = plt.subplots(figsize=(max(6, n_params * 0.6), max(5, n_params * 0.5)))
# Guard against degenerate matrices. n_params == 0 produces a 0x0 image
# that gives imshow degenerate xlim/ylim and trips matplotlib's
# identical-limits warning; an all-zero matrix similarly forces a singular
# color range. Render an empty-state axes in those cases.
kl_max = float(kl_matrix.max()) if kl_matrix.size else 0.0
if n_params == 0:
ax.text(
0.5,
0.5,
"No parameters to plot",
ha="center",
va="center",
transform=ax.transAxes,
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.0)
im = None
else:
vmax = kl_max if kl_max > 0.0 else 1.0
im = ax.imshow(kl_matrix, cmap="viridis", aspect="auto", vmin=0.0, vmax=vmax)
ax.set_xticks(np.arange(n_params))
ax.set_yticks(np.arange(n_params))
ax.set_xticklabels(names, rotation=45, ha="right", fontsize=8)
ax.set_yticklabels(names, fontsize=8)
ax.set_title("Pairwise KL Divergence")
if im is not None:
fig.colorbar(im, ax=ax)
fig.tight_layout()
if save_path is not None:
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
logger.info("Saved KL divergence matrix to %s", save_path)
plt.close(fig)
return fig
[docs]
def plot_convergence_diagnostics(
result: CMCResult,
save_path: Path | str | None = None,
) -> Figure:
"""Plot a 2x2 MCMC convergence summary figure.
Panels:
(0,0) ESS bulk bar chart
(0,1) R-hat bar chart with threshold
(1,0) BFMI bar chart per chain with threshold
(1,1) Text summary of key convergence statistics
Args:
result: CMC result with convergence diagnostics.
save_path: Path to save the figure.
Returns:
Matplotlib Figure.
"""
fig, axes = plt.subplots(2, 2, figsize=(12, 9))
fig.suptitle("MCMC Convergence Summary", fontsize=14)
names = result.parameter_names
x = np.arange(len(names))
# Panel (0,0): ESS bulk
ax_ess = axes[0, 0]
if result.ess_bulk is not None:
ax_ess.bar(x, result.ess_bulk, color="C0", alpha=0.8)
ax_ess.axhline(
y=ESS_THRESHOLD,
color="red",
linestyle="--",
alpha=0.5,
label=f"Min recommended ({ESS_THRESHOLD})",
)
ax_ess.set_xticks(x)
ax_ess.set_xticklabels(names, rotation=45, ha="right", fontsize=7)
ax_ess.set_ylabel("ESS (bulk)")
ax_ess.set_title("Effective Sample Size (bulk)")
ax_ess.legend(fontsize=7)
else:
ax_ess.text(
0.5,
0.5,
"ESS not available",
ha="center",
va="center",
transform=ax_ess.transAxes,
)
ax_ess.set_title("Effective Sample Size (bulk)")
# Panel (0,1): R-hat
ax_rhat = axes[0, 1]
if result.r_hat is not None:
colors = ["red" if rh > RHAT_THRESHOLD else "C0" for rh in result.r_hat]
ax_rhat.bar(x, result.r_hat, color=colors, alpha=0.8)
ax_rhat.axhline(
y=RHAT_THRESHOLD,
color="red",
linestyle="--",
alpha=0.5,
label=f"Threshold ({RHAT_THRESHOLD})",
)
ax_rhat.set_xticks(x)
ax_rhat.set_xticklabels(names, rotation=45, ha="right", fontsize=7)
ax_rhat.set_ylabel("R-hat")
ax_rhat.set_title("R-hat by Parameter")
ax_rhat.legend(fontsize=7)
else:
ax_rhat.text(
0.5,
0.5,
"R-hat not available",
ha="center",
va="center",
transform=ax_rhat.transAxes,
)
ax_rhat.set_title("R-hat by Parameter")
# Panel (1,0): BFMI
ax_bfmi = axes[1, 0]
if result.bfmi is not None:
chain_idx = np.arange(len(result.bfmi))
colors_bfmi = ["red" if b < BFMI_THRESHOLD else "C0" for b in result.bfmi]
ax_bfmi.bar(chain_idx, result.bfmi, color=colors_bfmi, alpha=0.8)
ax_bfmi.axhline(
y=BFMI_THRESHOLD,
color="red",
linestyle="--",
alpha=0.5,
label=f"Min threshold ({BFMI_THRESHOLD})",
)
ax_bfmi.set_xlabel("Chain")
ax_bfmi.set_ylabel("BFMI")
ax_bfmi.set_title("Bayesian Fraction of Missing Information")
ax_bfmi.legend(fontsize=7)
else:
ax_bfmi.text(
0.5,
0.5,
"BFMI not available",
ha="center",
va="center",
transform=ax_bfmi.transAxes,
)
ax_bfmi.set_title("Bayesian Fraction of Missing Information")
# Panel (1,1): Text summary
ax_text = axes[1, 1]
ax_text.axis("off")
ax_text.set_title("Convergence Summary")
divergences = result.metadata.get("divergent_transitions")
total_div = int(np.sum(divergences)) if divergences is not None else None
min_ess = float(np.min(result.ess_bulk)) if result.ess_bulk is not None else None
max_rhat = float(np.max(result.r_hat)) if result.r_hat is not None else None
min_bfmi = float(np.min(result.bfmi)) if result.bfmi is not None else None
# Convergence assessment
if max_rhat is not None and min_ess is not None:
converged = max_rhat < RHAT_THRESHOLD and min_ess > ESS_THRESHOLD
assessment = "Converged" if converged else "Not converged"
assess_color = "green" if converged else "red"
else:
assessment = "N/A"
assess_color = "gray"
summary_lines = [
f"Total divergences: {total_div if total_div is not None else 'N/A'}",
f"Min ESS (bulk): {min_ess:.1f}"
if min_ess is not None
else "Min ESS (bulk): N/A",
f"Max R-hat: {max_rhat:.4f}"
if max_rhat is not None
else "Max R-hat: N/A",
f"Min BFMI: {min_bfmi:.4f}"
if min_bfmi is not None
else "Min BFMI: N/A",
]
y_pos = 0.75
for line in summary_lines:
ax_text.text(
0.1,
y_pos,
line,
transform=ax_text.transAxes,
fontsize=11,
fontfamily="monospace",
verticalalignment="top",
)
y_pos -= 0.12
ax_text.text(
0.1,
y_pos,
f"Assessment: {assessment}",
transform=ax_text.transAxes,
fontsize=11,
fontfamily="monospace",
verticalalignment="top",
color=assess_color,
fontweight="bold",
)
fig.tight_layout(rect=(0, 0, 1, 0.95))
if save_path is not None:
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
logger.info("Saved convergence diagnostics to %s", save_path)
plt.close(fig)
return fig