# Source code for heterodyne.viz.mcmc_report

"""MCMC analysis report generation for heterodyne analysis.

Generates comprehensive Markdown reports summarizing NLSQ and CMC
results including parameter tables, convergence diagnostics, and
fit quality metrics.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING

from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from heterodyne.optimization.cmc.results import CMCResult
    from heterodyne.optimization.nlsq.results import NLSQResult

# Module-level logger; generate_report uses it to record where the report file was written.
logger = get_logger(__name__)


@dataclass
class ReportConfig:
    """Options controlling what the Markdown analysis report contains.

    Attributes:
        include_diagnostics: When True, each CMC result is followed by its
            R-hat / ESS table (plus the minimum BFMI when available).
        include_timing: When True, wall-clock time is appended for every
            result that reports one.
        include_correlation: Flag for parameter correlation analysis.
            NOTE(review): no code visible in this module reads this flag;
            confirm whether correlation output is implemented elsewhere.
        ci_level: Credible-interval level as a percentage string, e.g. "95"
            or "89". It selects which percentile keys are looked up in each
            CMC result's credible intervals.
        float_precision: Number of digits after the decimal point used when
            formatting floating-point values in the parameter tables.
    """

    # Section toggles.
    include_diagnostics: bool = True
    include_timing: bool = True
    include_correlation: bool = True
    # Number formatting / interval selection.
    ci_level: str = "95"
    float_precision: int = 4
[docs] def generate_report( nlsq_results: list[NLSQResult] | None = None, cmc_results: list[CMCResult] | None = None, output_dir: Path | str | None = None, config: ReportConfig | None = None, ) -> Path | str: """Generate a comprehensive Markdown analysis report. Creates a structured summary of the fitting results including: - Parameter estimates with uncertainties - Convergence diagnostics (R-hat, ESS, BFMI) - Fit quality metrics (chi-squared, cost) - Timing and configuration summary Args: nlsq_results: List of NLSQ results (one per phi angle). cmc_results: List of CMC results (one per phi angle). output_dir: Directory to write the report file. If None, returns the report as a string instead of writing. config: Report configuration. Uses defaults if None. Returns: Path to the written report file, or the report string if output_dir is None. """ if config is None: config = ReportConfig() sections: list[str] = [] # Header timestamp = datetime.now(tz=UTC).strftime("%Y-%m-%d %H:%M:%S UTC") sections.append(f"# Heterodyne Analysis Report\n\nGenerated: {timestamp}\n") # NLSQ Results if nlsq_results: sections.append("## NLSQ Optimization Results\n") for result in nlsq_results: phi = result.metadata.get("phi_angle", "N/A") sections.append(f"### Phi = {phi}\n") sections.append(_format_nlsq_table(result, config)) if config.include_timing and result.wall_time_seconds is not None: sections.append(f"\nWall time: {result.wall_time_seconds:.2f} s\n") # CMC Results if cmc_results: sections.append("## CMC Bayesian Results\n") for cmc_result in cmc_results: phi = cmc_result.metadata.get("phi_angle", "N/A") sections.append(f"### Phi = {phi}\n") sections.append(_format_cmc_table(cmc_result, config)) if config.include_diagnostics: sections.append(_format_diagnostics(cmc_result, config)) if config.include_timing and cmc_result.wall_time_seconds is not None: sections.append(f"\nWall time: {cmc_result.wall_time_seconds:.1f} s\n") report_text = "\n".join(sections) if output_dir is None: 
return report_text output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) report_path = output_dir / "analysis_report.md" report_path.write_text(report_text, encoding="utf-8") logger.info("Generated analysis report: %s", report_path) return report_path
def _format_nlsq_table(result: NLSQResult, config: ReportConfig) -> str:
    """Render one NLSQ result as a Markdown parameter table plus fit summary.

    Args:
        result: NLSQ result. Reads ``parameter_names``, ``parameters``,
            ``uncertainties`` (may be ``None`` or shorter than the parameter
            list), ``reduced_chi_squared`` and ``final_cost`` (each may be
            ``None``), ``success`` and ``n_function_evals``.
        config: Report options; only ``float_precision`` is read.

    Returns:
        Markdown text terminated by a newline.
    """
    prec = config.float_precision
    uncertainties = result.uncertainties
    lines = [
        "| Parameter | Value | Uncertainty |",
        "|-----------|-------|-------------|",
    ]
    for i, name in enumerate(result.parameter_names):
        val = result.parameters[i]
        # Bounds-check like _format_diagnostics: a missing or short
        # uncertainty vector renders "N/A" instead of raising IndexError.
        unc = (
            uncertainties[i]
            if uncertainties is not None and i < len(uncertainties)
            else None
        )
        unc_str = f"{unc:.{prec}e}" if unc is not None else "N/A"
        lines.append(f"| {name} | {val:.{prec}e} | {unc_str} |")

    # Markdown merges consecutive non-blank lines into one paragraph, so the
    # summary is a bullet list to keep each metric on its own rendered line.
    lines.append("")
    if result.reduced_chi_squared is not None:
        lines.append(f"- Reduced chi-squared: {result.reduced_chi_squared:.{prec}f}")
    if result.final_cost is not None:
        lines.append(f"- Final cost: {result.final_cost:.{prec}e}")
    lines.append(f"- Success: {result.success}")
    lines.append(f"- Function evaluations: {result.n_function_evals}")
    return "\n".join(lines) + "\n"


def _format_cmc_table(result: CMCResult, config: ReportConfig) -> str:
    """Render one CMC result as a Markdown posterior summary table.

    Args:
        result: CMC result. Reads ``parameter_names``, ``posterior_mean``,
            ``posterior_std``, ``credible_intervals`` (a mapping of parameter
            name -> mapping of percentile key -> bound), ``convergence_passed``,
            ``num_chains``, ``num_samples`` and ``num_warmup``.
        config: Report options; reads ``float_precision`` and ``ci_level``.

    Returns:
        Markdown text terminated by a newline. Parameters without both
        interval bounds for the requested level show "N/A" in the CI column.
    """
    prec = config.float_precision
    # credible_intervals stores ArviZ-style percentile keys, e.g. "2.5%" / "97.5%".
    ci_half = (100 - int(config.ci_level)) / 2
    ci_key_lo = f"{ci_half:.1f}%"
    ci_key_hi = f"{100 - ci_half:.1f}%"
    # Treat a missing mapping as "no intervals" rather than failing on `in`.
    intervals = result.credible_intervals or {}
    lines = [
        f"| Parameter | Mean | Std | CI {config.ci_level}% |",
        "|-----------|------|-----|---------|",
    ]
    for i, name in enumerate(result.parameter_names):
        mean = float(result.posterior_mean[i])
        std = float(result.posterior_std[i])
        ci_str = "N/A"
        ci = intervals.get(name)
        if ci is not None:
            lo = ci.get(ci_key_lo)
            hi = ci.get(ci_key_hi)
            if lo is not None and hi is not None:
                ci_str = f"[{lo:.{prec}e}, {hi:.{prec}e}]"
        lines.append(f"| {name} | {mean:.{prec}e} | {std:.{prec}e} | {ci_str} |")

    # Bullet items so the summary lines do not collapse into one paragraph.
    lines.append("")
    lines.append(f"- Convergence: {'PASSED' if result.convergence_passed else 'FAILED'}")
    lines.append(
        f"- Chains: {result.num_chains} | Samples: {result.num_samples} | Warmup: {result.num_warmup}"
    )
    return "\n".join(lines) + "\n"


def _format_diagnostics(result: CMCResult, config: ReportConfig) -> str:
    """Render per-parameter convergence diagnostics for one CMC result.

    Args:
        result: CMC result. Reads ``parameter_names``, the optional
            per-parameter sequences ``r_hat``, ``ess_bulk`` and ``ess_tail``,
            and the optional ``bfmi`` sequence. A diagnostic sequence that is
            ``None`` or shorter than the parameter list yields "N/A" cells.
        config: Report options; accepted for signature symmetry with the
            other formatters but not read.

    Returns:
        Markdown text (sub-heading, table, and a minimum-BFMI line when BFMI
        values are present) terminated by a newline.
    """

    def cell(values, i: int, spec: str) -> str:
        # Format values[i] with the given format spec, or "N/A" if absent.
        if values is None or i >= len(values):
            return "N/A"
        return format(values[i], spec)

    lines = [
        "\n#### Convergence Diagnostics\n",
        "| Parameter | R-hat | ESS bulk | ESS tail |",
        "|-----------|-------|----------|----------|",
    ]
    for i, name in enumerate(result.parameter_names):
        r_hat = cell(result.r_hat, i, ".3f")
        ess_b = cell(result.ess_bulk, i, ".0f")
        ess_t = cell(result.ess_tail, i, ".0f")
        lines.append(f"| {name} | {r_hat} | {ess_b} | {ess_t} |")

    # min() raises ValueError on an empty sequence; omit the line instead.
    # len() (not truthiness) keeps this valid for array-like BFMI values.
    bfmi = result.bfmi
    if bfmi is not None and len(bfmi) > 0:
        lines.append(f"\nMin BFMI: {min(bfmi):.3f}")
    return "\n".join(lines) + "\n"