Source code for heterodyne.optimization.cmc.results

"""Result container for CMC analysis."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    pass


[docs] class ParameterStats(dict): """Hybrid mapping/sequence for posterior summaries. Supports dict-style access by name (``ps["D0_ref"]``) and integer-index access (``ps[0]``). Inherits ``dict`` so existing ``.get()`` / ``in`` checks continue to work unchanged. """
[docs] def __init__(self, ordered_names: list[str], values: list[float]) -> None: super().__init__(zip(ordered_names, values, strict=True)) self._ordered_names = list(ordered_names) self._ordered_values = list(values)
def __getitem__(self, key: int | str) -> float: # type: ignore[override] if isinstance(key, int): return self._ordered_values[key] return super().__getitem__(key) def __len__(self) -> int: return len(self._ordered_values) def __array__(self, dtype=None) -> np.ndarray: return np.asarray(self._ordered_values, dtype=dtype) @property def as_array(self) -> np.ndarray: """Ordered values as a numpy array.""" return np.asarray(self._ordered_values, dtype=float)
[docs] def tolist(self) -> list[float]: return list(self._ordered_values)
[docs] @dataclass class CMCResult: """Result of CMC (Consensus Monte Carlo) analysis. Contains posterior samples, summaries, and convergence diagnostics. """ # Core results parameter_names: list[str] posterior_mean: np.ndarray posterior_std: np.ndarray credible_intervals: dict[str, dict[str, float]] # Convergence diagnostics convergence_passed: bool r_hat: np.ndarray | None = None ess_bulk: np.ndarray | None = None ess_tail: np.ndarray | None = None bfmi: list[float] | None = None # Full posterior samples samples: dict[str, np.ndarray] | None = None # MAP estimate (maximum a posteriori) map_estimate: np.ndarray | None = None # Sampling info num_warmup: int = 0 num_samples: int = 0 num_chains: int = 0 num_shards: int = 1 # First-class shard count (homodyne-parity). divergences: int = 0 # First-class divergence count (homodyne-parity). wall_time_seconds: float | None = None # Additional metadata metadata: dict[str, Any] = field(default_factory=dict) # Homodyne-parity fields (all optional; populated by fit_cmc_jax / merge_shard) convergence_status: str = ( "not_converged" # "converged"|"divergences"|"not_converged" ) warmup_time: float | None = None # Wall time for warmup phase only per_angle_mode: str = "auto" # Effective per-angle scaling mode used chi_squared: float | None = None # Post-combination chi-squared quality_flag: str | None = None # "good"|"warning"|"poor" mean_contrast: np.ndarray | None = None # Per-angle posterior contrast means std_contrast: np.ndarray | None = None # Per-angle posterior contrast stds mean_offset: np.ndarray | None = None # Per-angle posterior offset means std_offset: np.ndarray | None = None # Per-angle posterior offset stds # Homodyne-parity surface — populated by ``from_mcmc_samples`` / # ``cmc_result_to_arviz`` callers that want a top-level idata handle. inference_data: Any | None = None # az.InferenceData when arviz is installed @property def n_params(self) -> int: """Number of parameters.""" return len(self.parameter_names) # ------------------------------------------------------------------ # Homodyne-parity aliases (read-only properties over the canonical # heterodyne fields so cross-package callers can use either name). # ------------------------------------------------------------------ @property def parameters(self) -> np.ndarray: """Homodyne-parity alias for :attr:`posterior_mean`.""" return self.posterior_mean @property def uncertainties(self) -> np.ndarray: """Homodyne-parity alias for :attr:`posterior_std`.""" return self.posterior_std @property def param_names(self) -> list[str]: """Homodyne-parity alias for :attr:`parameter_names`.""" return self.parameter_names
[docs] def get_param_summary(self, name: str) -> dict[str, float]: """Get summary statistics for a parameter. Args: name: Parameter name Returns: Dict with mean, std, and credible interval bounds """ try: idx = self.parameter_names.index(name) except ValueError: raise KeyError(f"Parameter '{name}' not found") from None summary = { "mean": float(self.posterior_mean[idx]), "std": float(self.posterior_std[idx]), } if name in self.credible_intervals: summary.update(self.credible_intervals[name]) if self.r_hat is not None: summary["r_hat"] = float(self.r_hat[idx]) if self.ess_bulk is not None: summary["ess_bulk"] = float(self.ess_bulk[idx]) return summary
[docs] def get_samples(self, name: str) -> np.ndarray | None: """Get posterior samples for a parameter. Args: name: Parameter name Returns: Array of samples or None if not stored """ if self.samples is None: return None return self.samples.get(name)
[docs] def params_dict(self) -> dict[str, float]: """Get posterior means as dictionary.""" return { name: float(self.posterior_mean[i]) for i, name in enumerate(self.parameter_names) }
[docs] def validate_convergence( self, r_hat_threshold: float = 1.1, min_ess: int = 100, min_bfmi: float = 0.3, ) -> list[str]: """Validate convergence diagnostics. Args: r_hat_threshold: Maximum acceptable R-hat min_ess: Minimum effective sample size min_bfmi: Minimum BFMI value Returns: List of warning messages """ warnings = [] if self.r_hat is not None: bad_rhat = np.where(self.r_hat > r_hat_threshold)[0] for idx in bad_rhat: warnings.append( f"R-hat for {self.parameter_names[idx]}: " f"{self.r_hat[idx]:.3f} > {r_hat_threshold}" ) if self.ess_bulk is not None: low_ess = np.where(self.ess_bulk < min_ess)[0] for idx in low_ess: warnings.append( f"Low ESS for {self.parameter_names[idx]}: " f"{self.ess_bulk[idx]:.0f} < {min_ess}" ) if self.bfmi is not None: low_bfmi = [b for b in self.bfmi if b < min_bfmi] if low_bfmi: warnings.append(f"Low BFMI: {min(low_bfmi):.3f} < {min_bfmi}") return warnings
[docs] def summary(self) -> str: """Generate summary string.""" lines = [ "CMC Analysis Result", "=" * 60, f"Convergence: {'PASSED' if self.convergence_passed else 'FAILED'}", f"Chains: {self.num_chains} | Samples: {self.num_samples} | Warmup: {self.num_warmup}", "", "Posterior Summary:", "-" * 60, f"{'Parameter':18s} {'Mean':>12s} {'Std':>10s} {'R-hat':>8s} {'ESS':>8s}", "-" * 60, ] for i, name in enumerate(self.parameter_names): mean = self.posterior_mean[i] std = self.posterior_std[i] r_hat = self.r_hat[i] if self.r_hat is not None else np.nan ess = self.ess_bulk[i] if self.ess_bulk is not None else np.nan r_hat_str = f"{r_hat:.3f}" if not np.isnan(r_hat) else "N/A" ess_str = f"{ess:.0f}" if not np.isnan(ess) else "N/A" lines.append( f"{name:18s} {mean:12.4e} {std:10.2e} {r_hat_str:>8s} {ess_str:>8s}" ) lines.append("-" * 60) if self.wall_time_seconds is not None: lines.append(f"Wall time: {self.wall_time_seconds:.1f} s") return "\n".join(lines)
[docs] def get_samples_array(self) -> np.ndarray: """Return samples as a 3-D array of shape (num_chains, num_samples, n_params). Parameters with missing or None samples are filled with zeros. Flat 1-D sample arrays of shape ``num_chains * num_samples`` are reshaped to ``(num_chains, num_samples)`` automatically. """ n_params = len(self.parameter_names) out = np.zeros((self.num_chains, self.num_samples, n_params)) if self.samples is None: return out for i, name in enumerate(self.parameter_names): if name not in self.samples: continue arr = np.asarray(self.samples[name]) if arr.ndim == 1: total = arr.shape[0] nc = self.num_chains ns = self.num_samples if nc > 1 and total == nc * ns: arr = arr.reshape(nc, ns) else: arr = arr[np.newaxis, :] out[:, :, i] = arr[: self.num_chains, : self.num_samples] return out
[docs] def get_posterior_stats(self) -> dict[str, dict[str, float]]: """Return per-parameter posterior statistics. Returns a dict keyed by parameter name. Each value contains ``mean``, ``std``, ``median``, ``hdi_5%``, ``hdi_95%``, ``r_hat``, ``ess_bulk``, ``ess_tail``. Parameters absent from ``self.samples`` are omitted. """ stats: dict[str, dict[str, float]] = {} if self.samples is None: return stats for i, name in enumerate(self.parameter_names): if name not in self.samples: continue flat = np.asarray(self.samples[name]).flatten() r_hat_val = float(self.r_hat[i]) if self.r_hat is not None else float("nan") ess_bulk_val = ( float(self.ess_bulk[i]) if self.ess_bulk is not None else float("nan") ) ess_tail_val = ( float(self.ess_tail[i]) if self.ess_tail is not None else float("nan") ) stats[name] = { "mean": float(np.nanmean(flat)), "std": float(np.nanstd(flat)), "median": float(np.nanmedian(flat)), "hdi_5%": float(np.nanpercentile(flat, 5)), "hdi_95%": float(np.nanpercentile(flat, 95)), "r_hat": r_hat_val, "ess_bulk": ess_bulk_val, "ess_tail": ess_tail_val, } return stats
[docs] @classmethod def from_mcmc_samples( cls, mcmc_samples: Any, stats: Any, analysis_mode: str = "static", n_warmup: int = 500, min_ess: float | None = None, # noqa: ARG003 — accepted for homodyne parity ) -> CMCResult: """Build a :class:`CMCResult` from raw MCMC samples (homodyne parity). Mirrors ``homodyne.optimization.cmc.results.CMCResult.from_mcmc_samples``. Duck-typed: ``mcmc_samples`` must expose ``.samples`` (dict[str, ndarray]), ``.param_names`` (list[str]), ``.n_chains`` (int), ``.n_samples`` (int); ``stats`` must expose ``.num_divergent`` (int) and may expose ``.wall_time`` / ``.warmup_time``. Diagnostics (R-hat, ESS) are not computed here — they require per-chain reshaping and ArviZ. Callers that need diagnostics should run :func:`cmc_result_to_arviz` and overwrite ``.r_hat`` / ``.ess_*`` after construction. Parameters ---------- mcmc_samples: Object holding posterior draws; see duck-typed surface above. stats: Sampling statistics object; see duck-typed surface above. analysis_mode: Stored on the result for plot / report consumers. Default ``"static"`` mirrors homodyne. n_warmup: Number of warmup draws (recorded on the result). min_ess: Accepted for homodyne parity; ignored here because diagnostics are not computed by this factory. Returns ------- CMCResult Populated result with ``parameter_names``, ``posterior_mean``, ``posterior_std``, ``samples``, and basic sampling/divergence metadata. ``credible_intervals`` is left empty; downstream consumers can populate it via ``get_posterior_stats``. """ del min_ess # parity-only; convergence checks are not run here samples_dict = dict(mcmc_samples.samples) param_names = list(mcmc_samples.param_names) n_params = len(param_names) n_chains = int(getattr(mcmc_samples, "n_chains", 4)) n_samples = int(getattr(mcmc_samples, "n_samples", 0)) num_shards = int(getattr(mcmc_samples, "num_shards", 1)) posterior_mean = np.zeros(n_params) posterior_std = np.zeros(n_params) for i, name in enumerate(param_names): if name in samples_dict: flat = np.asarray(samples_dict[name]).ravel() posterior_mean[i] = float(np.nanmean(flat)) posterior_std[i] = float(np.nanstd(flat)) divergences = int(getattr(stats, "num_divergent", 0)) wall_time_raw = getattr(stats, "wall_time", None) warmup_time_raw = getattr(stats, "warmup_time", None) wall_time_seconds = float(wall_time_raw) if wall_time_raw is not None else None warmup_time = float(warmup_time_raw) if warmup_time_raw is not None else None convergence_passed = divergences == 0 convergence_status = "converged" if convergence_passed else "divergences" return cls( parameter_names=param_names, posterior_mean=posterior_mean, posterior_std=posterior_std, credible_intervals={}, convergence_passed=convergence_passed, samples=samples_dict, num_warmup=n_warmup, num_samples=n_samples, num_chains=n_chains, num_shards=num_shards, divergences=divergences, wall_time_seconds=wall_time_seconds, warmup_time=warmup_time, convergence_status=convergence_status, per_angle_mode=analysis_mode, )
# --------------------------------------------------------------------------- # Standalone functions operating on CMCResult # ---------------------------------------------------------------------------
[docs] def cmc_result_to_arviz(result: CMCResult) -> Any: """Convert a CMCResult to an ArviZ InferenceData object. Samples stored in ``result.samples`` are reshaped to ``(num_chains, num_draws)`` when ``result.num_chains > 1`` so that ArviZ can compute per-chain diagnostics (R-hat, ESS). When the result carries flat 1-D arrays the function treats the entire sequence as a single chain. Args: result: Completed CMC analysis result. Returns: ``arviz.InferenceData`` with a ``posterior`` group populated from ``result.samples`` and, when available, ``sample_stats`` populated from ``result.bfmi``. Raises: ImportError: If ArviZ is not installed. ValueError: If ``result.samples`` is None or empty. """ try: import arviz as az # type: ignore[import-untyped] except ImportError: raise ImportError( "ArviZ is required for cmc_result_to_arviz. Install it with: uv add arviz" ) from None if not result.samples: raise ValueError( "CMCResult.samples is None or empty; cannot build InferenceData." ) n_chains = max(result.num_chains, 1) posterior_dict: dict[str, np.ndarray] = {} for name, arr in result.samples.items(): arr = np.asarray(arr) if arr.ndim == 1: total = arr.shape[0] n_draws = total // n_chains if n_chains > 1 and total % n_chains == 0: # Reshape to (chains, draws) posterior_dict[name] = arr.reshape(n_chains, n_draws) else: # Fall back to single chain posterior_dict[name] = arr[np.newaxis, :] elif arr.ndim == 2: # Already (chains, draws) posterior_dict[name] = arr else: # Higher-dimensional parameter (e.g. covariance matrix per draw) posterior_dict[name] = arr idata = az.from_dict({"posterior": posterior_dict}) if result.bfmi is not None: sample_stats = {"energy": np.array(result.bfmi)} idata.add_groups({"sample_stats": sample_stats}) return idata
[docs] def compare_cmc_nlsq( cmc_result: CMCResult, nlsq_result: Any, consistency_sigma: float = 2.0, ) -> dict[str, Any]: """Compare CMC posterior means with NLSQ point estimates. Parameters that appear in both results are compared. Parameters present in only one result are silently skipped. Args: cmc_result: Completed CMC result. nlsq_result: Completed NLSQ result (``NLSQResult`` instance). consistency_sigma: Number of posterior standard deviations within which the NLSQ estimate must fall to be flagged as consistent. Defaults to 2.0 (approximately 95 % credible interval). Returns: Dictionary with keys: - ``"common_parameters"`` — list of parameter names present in both. - ``"differences"`` — dict mapping name to ``(cmc_mean - nlsq_value)``. - ``"relative_deviations"`` — dict mapping name to ``abs(cmc_mean - nlsq_value) / cmc_std``. - ``"consistent"`` — dict mapping name to bool (True if within ``consistency_sigma`` posterior std of the CMC mean). - ``"n_consistent"`` — int count of consistent parameters. - ``"n_inconsistent"`` — int count of inconsistent parameters. - ``"consistency_sigma"`` — the threshold used. """ cmc_means = cmc_result.params_dict() # NLSQResult.params_dict is a property, not a method nlsq_params: dict[str, float] = nlsq_result.params_dict common = sorted(set(cmc_means) & set(nlsq_params)) differences: dict[str, float] = {} relative_deviations: dict[str, float] = {} consistent: dict[str, bool] = {} for name, cmc_val, nlsq_val, cmc_std in zip( common, [cmc_means[n] for n in common], [nlsq_params[n] for n in common], [ float(cmc_result.posterior_std[cmc_result.parameter_names.index(n)]) for n in common ], strict=True, ): diff = cmc_val - nlsq_val differences[name] = diff if cmc_std > 0.0: rel_dev = abs(diff) / cmc_std else: rel_dev = float("inf") if diff != 0.0 else 0.0 relative_deviations[name] = rel_dev consistent[name] = rel_dev <= consistency_sigma n_consistent = sum(1 for v in consistent.values() if v) n_inconsistent = len(consistent) - n_consistent return { "common_parameters": common, "differences": differences, "relative_deviations": relative_deviations, "consistent": consistent, "n_consistent": n_consistent, "n_inconsistent": n_inconsistent, "consistency_sigma": consistency_sigma, }
[docs] def merge_shard_cmc_results( shard_results: list[CMCResult], parameter_names: list[str] | None = None, ) -> CMCResult: """Combine multiple shard CMCResults into a single consensus result. Uses inverse-variance weighting (precision weighting) to combine posterior means from independent shards, following the Consensus Monte Carlo methodology (Scott et al., 2016). Diagnostics (R-hat, ESS, BFMI) are set to their worst-case values across shards so that failures are never hidden by averaging. Args: shard_results: Non-empty list of per-shard CMCResults. Each must have the same ``parameter_names`` (or ``parameter_names`` override must be supplied). parameter_names: Optional explicit parameter name list. When supplied, only these parameters are included in the merged result; they must be present in every shard. Returns: A new ``CMCResult`` representing the consensus posterior. Raises: ValueError: If ``shard_results`` is empty or parameter names are inconsistent across shards when no override is given. """ if not shard_results: raise ValueError("shard_results must be non-empty.") # --- Hierarchical combination for large shard counts (CM-09) --- # Combining K > 500 results at once requires O(K) memory for stacking # arrays. Chunking into groups of 500 and merging recursively reduces # peak memory to O(500) × ceil(K/500) with identical numerical result. _CHUNK = 500 if len(shard_results) > _CHUNK: chunks = [ shard_results[i : i + _CHUNK] for i in range(0, len(shard_results), _CHUNK) ] chunk_merged = [merge_shard_cmc_results(c, parameter_names) for c in chunks] return merge_shard_cmc_results(chunk_merged, parameter_names) # Determine canonical parameter names if parameter_names is None: parameter_names = shard_results[0].parameter_names for i, sr in enumerate(shard_results[1:], start=1): if sr.parameter_names != parameter_names: raise ValueError( f"Shard {i} has parameter_names {sr.parameter_names!r} " f"but shard 0 has {parameter_names!r}. " "Pass parameter_names explicitly to override." ) n_params = len(parameter_names) # --- Inverse-variance weighted combination --- # precision_i = 1 / std_i^2 (per parameter) # combined_mean = sum(precision_i * mean_i) / sum(precision_i) # combined_std = 1 / sqrt(sum(precision_i)) precision_sum = np.zeros(n_params, dtype=np.float64) weighted_mean_sum = np.zeros(n_params, dtype=np.float64) for sr in shard_results: std = np.asarray(sr.posterior_std, dtype=np.float64) # Guard against zero std (degenerate shards) std = np.where(std > 0.0, std, np.finfo(np.float64).tiny) precision = 1.0 / (std**2) precision_sum += precision mean = np.asarray(sr.posterior_mean, dtype=np.float64) weighted_mean_sum += precision * mean combined_mean = weighted_mean_sum / precision_sum combined_std = 1.0 / np.sqrt(precision_sum) # --- Worst-case diagnostics --- r_hat_arrays = [sr.r_hat for sr in shard_results if sr.r_hat is not None] combined_r_hat: np.ndarray | None = None if r_hat_arrays: combined_r_hat = np.max(np.stack(r_hat_arrays, axis=0), axis=0) ess_bulk_arrays = [sr.ess_bulk for sr in shard_results if sr.ess_bulk is not None] combined_ess_bulk: np.ndarray | None = None if ess_bulk_arrays: combined_ess_bulk = np.nansum(np.stack(ess_bulk_arrays, axis=0), axis=0) ess_tail_arrays = [sr.ess_tail for sr in shard_results if sr.ess_tail is not None] combined_ess_tail: np.ndarray | None = None if ess_tail_arrays: combined_ess_tail = np.nansum(np.stack(ess_tail_arrays, axis=0), axis=0) all_bfmi: list[float] | None = None bfmi_lists = [sr.bfmi for sr in shard_results if sr.bfmi is not None] if bfmi_lists: all_bfmi = [min(bfmi_list) for bfmi_list in bfmi_lists] # --- Credible intervals: rebuild from combined mean/std (Gaussian approx) --- z95 = 1.959963985 # scipy.stats.norm.ppf(0.975) z89 = 1.598193423 # scipy.stats.norm.ppf(0.945) credible_intervals: dict[str, dict[str, float]] = {} for i, name in enumerate(parameter_names): mu = float(combined_mean[i]) sigma = float(combined_std[i]) credible_intervals[name] = { "lower_95": mu - z95 * sigma, "upper_95": mu + z95 * sigma, "lower_89": mu - z89 * sigma, "upper_89": mu + z89 * sigma, } # Convergence: all shards must pass convergence_passed = all(sr.convergence_passed for sr in shard_results) # --- Aggregate samples (concatenate across shards) --- combined_samples: dict[str, np.ndarray] | None = None if all(sr.samples is not None for sr in shard_results): combined_samples = {} for name in parameter_names: arrays = [ np.asarray(sr.samples[name]) # type: ignore[index] for sr in shard_results if sr.samples is not None and name in sr.samples ] if arrays: combined_samples[name] = np.concatenate(arrays, axis=0) total_samples = sum(sr.num_samples for sr in shard_results) total_chains = sum(sr.num_chains for sr in shard_results) max_warmup = max(sr.num_warmup for sr in shard_results) total_wall_time: float | None = None wall_times = [ sr.wall_time_seconds for sr in shard_results if sr.wall_time_seconds is not None ] if wall_times: total_wall_time = max(wall_times) # Parallel shards: wall time = max shard # Sum divergences across shards (homodyne-parity). Each shard's per-shard # divergence count lives in its metadata under "num_divergences" or # "divergences"; sum whichever is present. total_divergences = 0 for _sr in shard_results: _md = getattr(_sr, "metadata", None) or {} _div = _md.get("num_divergences", _md.get("divergences", 0)) try: total_divergences += int(_div) except (TypeError, ValueError): continue # Also count first-class field if populated on the shard. if getattr(_sr, "divergences", 0): # Avoid double-counting if metadata also held it; assume metadata # is the canonical legacy source. pass return CMCResult( parameter_names=list(parameter_names), posterior_mean=combined_mean, posterior_std=combined_std, credible_intervals=credible_intervals, convergence_passed=convergence_passed, r_hat=combined_r_hat, ess_bulk=combined_ess_bulk, ess_tail=combined_ess_tail, bfmi=all_bfmi, samples=combined_samples, map_estimate=None, num_warmup=max_warmup, num_samples=total_samples, num_chains=total_chains, num_shards=len(shard_results), divergences=total_divergences, wall_time_seconds=total_wall_time, metadata={ "n_shards": len(shard_results), "num_divergences": total_divergences, "combination_method": "inverse_variance", }, convergence_status=( "converged" if convergence_passed else ( "divergences" if any( getattr(sr, "metadata", {}).get("divergence_rate", 0.0) > 0.05 for sr in shard_results ) else "not_converged" ) ), )
[docs] def cmc_result_summary_table( result: CMCResult, ci_level: str = "95", width: int = 80, ) -> str: """Format a CMCResult as a human-readable parameter summary table. The table includes columns for parameter name, posterior mean, posterior standard deviation, credible interval bounds, R-hat, and bulk ESS. Missing diagnostics are shown as ``N/A``. Args: result: Completed CMC analysis result. ci_level: Credible interval level to display. Must be ``"95"`` or ``"89"``. Defaults to ``"95"``. width: Total character width of the horizontal rule separators. Defaults to 80. Returns: Multi-line string containing the formatted table. Raises: ValueError: If ``ci_level`` is not ``"95"`` or ``"89"``. """ if ci_level not in {"95", "89"}: raise ValueError(f"ci_level must be '95' or '89', got {ci_level!r}.") lower_key = f"lower_{ci_level}" upper_key = f"upper_{ci_level}" sep = "-" * width header_sep = "=" * width col_name = 18 col_mean = 13 col_std = 11 col_ci = 25 col_rhat = 8 col_ess = 8 header = ( f"{'Parameter':<{col_name}}" f"{'Mean':>{col_mean}}" f"{'Std':>{col_std}}" f"{f'CI {ci_level}%':^{col_ci}}" f"{'R-hat':>{col_rhat}}" f"{'ESS':>{col_ess}}" ) lines = [ "CMC Posterior Summary", header_sep, f"Convergence : {'PASSED' if result.convergence_passed else 'FAILED'}", f"Chains : {result.num_chains}", f"Samples : {result.num_samples}", f"Warmup : {result.num_warmup}", ] if result.wall_time_seconds is not None: lines.append(f"Wall time : {result.wall_time_seconds:.1f} s") lines += ["", header, sep] for i, name in enumerate(result.parameter_names): mean = float(result.posterior_mean[i]) std = float(result.posterior_std[i]) ci_str = "N/A" if name in result.credible_intervals: ci = result.credible_intervals[name] lo = ci.get(lower_key) hi = ci.get(upper_key) if lo is not None and hi is not None: ci_str = f"[{lo:.3e}, {hi:.3e}]" r_hat_val = result.r_hat[i] if result.r_hat is not None else float("nan") ess_val = result.ess_bulk[i] if result.ess_bulk is not None else float("nan") r_hat_str = f"{r_hat_val:.3f}" if not np.isnan(r_hat_val) else "N/A" ess_str = f"{ess_val:.0f}" if not np.isnan(ess_val) else "N/A" row = ( f"{name:<{col_name}}" f"{mean:>{col_mean}.4e}" f"{std:>{col_std}.2e}" f"{ci_str:^{col_ci}}" f"{r_hat_str:>{col_rhat}}" f"{ess_str:>{col_ess}}" ) lines.append(row) lines.append(sep) if result.bfmi is not None: min_bfmi = min(result.bfmi) bfmi_flag = " [LOW]" if min_bfmi < 0.3 else "" lines.append(f"Min BFMI : {min_bfmi:.3f}{bfmi_flag}") return "\n".join(lines)