"""Optimization execution for heterodyne CLI.
Manages NLSQ and CMC fitting runs, including warm-start resolution.
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
from heterodyne.io.mcmc_writers import format_mcmc_summary, save_mcmc_results
from heterodyne.io.nlsq_writers import (
format_nlsq_summary,
save_nlsq_json_files,
save_nlsq_npz_file,
)
from heterodyne.optimization.cmc import CMCConfig
from heterodyne.optimization.cmc.core import (
CMC_ALPHA_SINGULARITY as _ALPHA_SINGULARITY,
)
from heterodyne.optimization.cmc.core import (
CMC_F0_DEGEN_THRESHOLD as _F0_DEGEN_THRESHOLD,
)
from heterodyne.optimization.cmc.core import fit_cmc_multi_phi
from heterodyne.optimization.nlsq import NLSQConfig, fit_nlsq_multi_phi
from heterodyne.optimization.nlsq.results import NLSQResult
from heterodyne.utils.logging import AnalysisSummaryLogger, get_logger, log_phase
if TYPE_CHECKING:
from heterodyne.config.manager import ConfigManager
from heterodyne.core.heterodyne_model import HeterodyneModel
from heterodyne.optimization.cmc.results import CMCResult
# Module-level logger used by the functions in this module.
logger = get_logger(__name__)
def _closest_phi_index(data_phi_angles: np.ndarray, target: float) -> int:
"""Return the index of the data phi angle closest to *target* (degrees).
Uses circular distance so that 179° and -179° are treated as 2° apart
instead of 358° apart. A linear ``argmin(abs(d - t))`` would pick the
wrong slice near the ±180° boundary.
"""
normalized_data = (
(np.asarray(data_phi_angles, dtype=float) + 180.0) % 360.0
) - 180.0
normalized_target = ((float(target) + 180.0) % 360.0) - 180.0
delta = ((normalized_data - normalized_target + 180.0) % 360.0) - 180.0
return int(np.argmin(np.abs(delta)))
def _select_c2_for_phi_angles(
c2_data: np.ndarray,
phi_angles: list[float],
data_phi_angles: np.ndarray | None = None,
) -> np.ndarray:
"""Return a C2 stack aligned with selected phi angles."""
if c2_data.ndim != 3:
return c2_data
slices: list[np.ndarray] = []
for i, phi in enumerate(phi_angles):
if data_phi_angles is not None and len(data_phi_angles) == c2_data.shape[0]:
idx = _closest_phi_index(data_phi_angles, phi)
logger.info(
"Selected data slice %d (phi=%.2f°) for fitting phi=%.2f°",
idx,
float(data_phi_angles[idx]),
phi,
)
slices.append(c2_data[idx])
else:
slices.append(c2_data[i])
return np.stack(slices, axis=0)
def _combine_nlsq_results(results: list[NLSQResult]) -> NLSQResult:
"""Build a single aggregate result for disk output."""
if not results:
raise ValueError("Cannot combine empty NLSQ result list")
first = results[0]
def _stack_optional(attr: str) -> np.ndarray | None:
values = [getattr(result, attr) for result in results]
if any(value is None for value in values):
return None
return np.stack([np.asarray(value) for value in values], axis=0)
residual_values = [result.residuals for result in results]
residuals = (
np.concatenate([np.asarray(value).ravel() for value in residual_values])
if all(value is not None for value in residual_values)
else None
)
costs = [
float(result.final_cost) for result in results if result.final_cost is not None
]
final_cost = (
float(0.5 * np.sum(residuals**2))
if residuals is not None
else (float(np.sum(costs)) if costs else None)
)
chi2_values = [
float(result.reduced_chi_squared)
for result in results
if result.reduced_chi_squared is not None
]
metadata = {
"aggregate": True,
"n_angles": len(results),
"phi_angles": [result.metadata.get("phi_angle") for result in results],
"per_angle": [
{
"phi_angle": result.metadata.get("phi_angle"),
"success": result.success,
"message": result.message,
"final_cost": result.final_cost,
"reduced_chi_squared": result.reduced_chi_squared,
}
for result in results
],
}
return NLSQResult(
parameters=np.asarray(first.parameters),
parameter_names=list(first.parameter_names),
success=all(result.success for result in results),
message="multi-angle NLSQ complete",
uncertainties=first.uncertainties,
covariance=first.covariance,
final_cost=final_cost,
reduced_chi_squared=float(np.mean(chi2_values)) if chi2_values else None,
n_iterations=max((result.n_iterations for result in results), default=0),
n_function_evals=sum(result.n_function_evals for result in results),
convergence_reason=first.convergence_reason,
residuals=residuals,
jacobian=None,
fitted_correlation=_stack_optional("fitted_correlation"),
wall_time_seconds=first.metadata.get("wall_time_total")
or first.wall_time_seconds,
metadata=metadata,
)
[docs]
def run_nlsq(
    model: HeterodyneModel,
    c2_data: np.ndarray,
    phi_angles: list[float],
    config_manager: ConfigManager,
    args: argparse.Namespace,
    output_dir: Path,
    summary: AnalysisSummaryLogger | None = None,
    data_phi_angles: np.ndarray | None = None,
) -> list[NLSQResult]:
    """Run NLSQ analysis for all phi angles.

    Fits every requested angle with :func:`fit_nlsq_multi_phi`, annotates
    each result (phi angle, bound-saturation warnings, recovery plan for
    failed fits), logs per-angle and batch summaries, and writes one
    aggregate result to *output_dir*.

    Args:
        model: Configured HeterodyneModel.
        c2_data: Correlation data (2D or 3D).
        phi_angles: Phi angles to analyze.
        config_manager: Configuration manager.
        args: CLI arguments. Reads ``verbose`` (default 1) and
            ``output_format`` (``"json"``, ``"npz"`` or ``"both"``; default
            ``"both"``).
        output_dir: Output directory for results.
        summary: Optional summary logger for phase tracking.
        data_phi_angles: Optional phi angle of each slice of a 3D *c2_data*;
            used to pick the slice closest to each requested angle.

    Returns:
        List of NLSQResult objects, one per phi angle.
    """
    logger.info("Starting NLSQ analysis")
    nlsq_config = NLSQConfig.from_dict(config_manager.nlsq_config)
    nlsq_config.verbose = getattr(args, "verbose", 1)
    c2_fit = _select_c2_for_phi_angles(c2_data, phi_angles, data_phi_angles)

    with log_phase("nlsq_multi_phi", logger=logger, track_memory=True) as phase:
        results = fit_nlsq_multi_phi(
            model=model,
            c2_data=c2_fit,
            phi_angles=phi_angles,
            config=nlsq_config,
        )
    logger.info(
        "NLSQ multi-angle optimization completed in %.2fs for %d phi angles",
        phase.duration,
        len(phi_angles),
    )

    for i, (phi, result) in enumerate(zip(phi_angles, results, strict=True)):
        result.metadata["phi_angle"] = phi
        _warn_nlsq_bound_saturation(result)
        if summary is not None and result.reduced_chi_squared is not None:
            summary.record_metric(
                f"nlsq_chi2_phi{int(phi)}", result.reduced_chi_squared
            )
        # Post-fit RecoveryPlan diagnosis for failed angles (homodyne parity).
        if not result.success:
            _attach_nlsq_recovery_plan(result, nlsq_config, phi)
        logger.info(
            "NLSQ Results for phi=%s° (%d/%d)\n%s\n%s",
            phi,
            i + 1,
            len(results),
            "=" * 50,
            format_nlsq_summary(result),
        )

    aggregate = _combine_nlsq_results(results)
    # Per-phi batch statistics (homodyne parity) need at least two angles.
    if len(results) >= 2:
        _log_nlsq_batch_statistics(results, summary)

    output_format = getattr(args, "output_format", "both")
    if output_format in ("json", "both"):
        saved_json = save_nlsq_json_files(aggregate, output_dir, prefix="nlsq")
        for label, path in saved_json.items():
            logger.info("Saved NLSQ %s: %s", label, path)
    if output_format in ("npz", "both"):
        npz_path = output_dir / "nlsq_data.npz"
        save_nlsq_npz_file(aggregate, npz_path, c2_exp=c2_fit)
        logger.info("Saved NLSQ data: %s", npz_path)
    if output_format not in ("json", "npz", "both"):
        # Otherwise nothing would be written and the user would not know why.
        logger.warning(
            "Unknown output_format %r; NLSQ results were not written to disk",
            output_format,
        )

    logger.info("NLSQ analysis complete")
    return results


def _attach_nlsq_recovery_plan(
    result: NLSQResult, nlsq_config: NLSQConfig, phi: float
) -> None:
    """Best-effort: diagnose a failed fit and store its recovery plan.

    The plan's action and message are logged and written to
    ``result.metadata["recovery_plan"]``. If the diagnostics module cannot be
    imported or ``diagnose_failure`` raises ValueError/AttributeError/
    RuntimeError, a warning is logged and the result is left unchanged.
    """
    try:
        from heterodyne.optimization.nlsq.recovery import diagnose_failure
    except ImportError as exc:
        logger.warning(
            "Recovery diagnostics unavailable for phi=%s° (%s); skipping", phi, exc
        )
        return
    try:
        plan = diagnose_failure(result, nlsq_config)
        action = plan.action.value
        message = plan.message
    except (ValueError, AttributeError, RuntimeError) as exc:
        logger.warning(
            "diagnose_failure crashed on phi=%s° (%s); recovery plan unavailable",
            phi,
            exc,
        )
        return
    logger.warning(
        "NLSQ phi=%s° failed → recovery plan: %s (%s)", phi, action, message
    )
    result.metadata["recovery_plan"] = {"action": action, "message": message}


def _log_nlsq_batch_statistics(
    results: list[NLSQResult], summary: AnalysisSummaryLogger | None
) -> None:
    """Best-effort: log cross-angle batch statistics and record summary metrics.

    Failures to import, compute or format the statistics (ImportError,
    ValueError, AttributeError) are logged as a warning and otherwise ignored.
    """
    try:
        from heterodyne.optimization.batch_statistics import (
            compute_batch_statistics,
            format_batch_report,
        )

        batch = compute_batch_statistics(results)
        logger.info("NLSQ batch statistics:\n%s", format_batch_report(batch))
        if summary is not None:
            summary.record_metric(
                "nlsq_batch_success_rate", batch.overall_success_rate
            )
            summary.record_metric("nlsq_batch_mean_chi2", batch.mean_chi2)
    except (ImportError, ValueError, AttributeError) as exc:
        logger.warning("Batch statistics unavailable (%s); continuing", exc)
[docs]
def run_cmc(
    model: HeterodyneModel,
    c2_data: np.ndarray,
    phi_angles: list[float],
    config_manager: ConfigManager,
    args: argparse.Namespace,
    output_dir: Path,
    nlsq_results: list[NLSQResult] | None = None,
    summary: AnalysisSummaryLogger | None = None,
    data_phi_angles: np.ndarray | None = None,
) -> CMCResult:
    """Run CMC Bayesian analysis for all phi angles.

    All requested angles are fitted in ONE joint inference via
    :func:`fit_cmc_multi_phi`. CLI ``num_samples`` / ``num_chains`` (when set)
    override the configuration before it is read, and a configured ``"jit"``
    backend is replaced by ``"pjit"``. NLSQ results, when given, are
    quality-checked, interior-clamped and handed to the engine as
    warm-starts, paired with phi angles by position.

    Args:
        model: Configured HeterodyneModel.
        c2_data: Correlation data. A 3D array is sliced per phi angle; any
            other array is reused as the matrix for every angle.
        phi_angles: Phi angles to analyze.
        config_manager: Configuration manager (updated in place by CLI
            overrides).
        args: CLI arguments. Reads ``num_samples``, ``num_chains`` and
            ``output_format`` (``"json"``, ``"npz"`` or ``"both"``).
        output_dir: Output directory.
        nlsq_results: Optional NLSQ results for warm-starting, expected in
            the same order as *phi_angles*.
        summary: Optional summary logger for phase tracking.
        data_phi_angles: Optional phi angle of each slice of a 3D *c2_data*;
            used to pick the slice closest to each requested angle.

    Returns:
        Joint multi-phi :class:`CMCResult` (homodyne parity). Reflects ONE
        NUTS inference across all phi angles with shared 14 physics params
        and per-angle scaling in ``mean_contrast`` / ``std_contrast`` /
        ``mean_offset`` / ``std_offset`` arrays of length ``n_phi``.
    """
    logger.info("Starting CMC analysis (joint multi-phi, homodyne parity)")
    _apply_cmc_cli_overrides(config_manager, args)
    cmc_config = CMCConfig.from_dict(config_manager.cmc_config)
    if cmc_config.backend_name == "jit":
        logger.info(
            "CMC backend 'jit' requested; using 'pjit' for joint multi-phi inference"
        )
        cmc_config.backend_name = "pjit"

    # ---- Stack per-angle c2 slices into (n_phi, N, N) for joint inference ----
    c2_stacked = _stack_c2_for_cmc(c2_data, phi_angles, data_phi_angles)
    nlsq_for_engine = _prepare_cmc_warmstarts(model, phi_angles, nlsq_results)

    n_points = int(c2_stacked.size)
    logger.info(
        "[CMC joint] n_phi=%d, total_points=%d, num_warmup=%d, num_samples=%d, "
        "num_chains=%d",
        len(phi_angles),
        n_points,
        cmc_config.num_warmup,
        cmc_config.num_samples,
        cmc_config.num_chains,
    )
    with log_phase("cmc_joint_multi_phi", logger=logger, track_memory=True) as phase:
        result = fit_cmc_multi_phi(
            model=model,
            c2_data=c2_stacked,
            phi_angles=list(phi_angles),
            config=cmc_config,
            nlsq_results=nlsq_for_engine,
        )
    result.metadata["phi_angles"] = [float(p) for p in phi_angles]
    result.metadata["joint_multi_phi_runtime_s"] = phase.duration
    logger.info(
        "CMC joint multi-phi completed in %.2fs (n_phi=%d)",
        phase.duration,
        len(phi_angles),
    )
    if summary is not None:
        summary.record_metric("cmc_n_samples", float(cmc_config.num_samples))
        summary.record_metric("cmc_n_phi", float(len(phi_angles)))
    logger.info(
        "\n%s\nCMC Results (joint multi-phi)\n%s",
        "=" * 50,
        format_mcmc_summary(result),
    )

    output_format = getattr(args, "output_format", "both")
    prefix = "cmc"
    saved_paths = save_mcmc_results(result, output_dir, prefix=prefix)
    _prune_cmc_outputs(saved_paths, output_format)
    logger.info("Saved CMC results → %s (prefix=%s)", output_dir, prefix)

    if _is_degenerate_cmc_result(result):
        logger.error(
            "[CMC] Joint multi-phi result is degenerate (no usable samples). "
            "The warm-start parameters may be degenerate or the model is "
            "unidentifiable for this dataset. Fixes: freeze degenerate params "
            "in YAML (optimization.cmc.fixed_params), tighten NLSQ bounds, or "
            "use allow_degenerate_warmstart: true to override."
        )
    logger.info("CMC analysis complete")
    return result


def _apply_cmc_cli_overrides(
    config_manager: ConfigManager, args: argparse.Namespace
) -> None:
    """Copy CLI ``num_samples`` / ``num_chains`` (when set) into the CMC config."""
    if getattr(args, "num_samples", None) is not None:
        logger.info("Overriding CMC num_samples from CLI: %s", args.num_samples)
        config_manager.update_optimization_config(
            "cmc", "num_samples", args.num_samples
        )
    if getattr(args, "num_chains", None) is not None:
        logger.info("Overriding CMC num_chains from CLI: %s", args.num_chains)
        config_manager.update_optimization_config("cmc", "num_chains", args.num_chains)


def _stack_c2_for_cmc(
    c2_data: np.ndarray,
    phi_angles: list[float],
    data_phi_angles: np.ndarray | None,
) -> np.ndarray:
    """Return one C2 slice per requested phi angle, stacked along axis 0.

    3D data uses the same per-angle slice selection as the NLSQ path; any
    other input is treated as a single matrix shared by every angle.
    """
    if c2_data.ndim == 3:
        return _select_c2_for_phi_angles(c2_data, phi_angles, data_phi_angles)
    shared = np.asarray(c2_data)
    return np.stack([shared for _ in phi_angles], axis=0)


def _prepare_cmc_warmstarts(
    model: HeterodyneModel,
    phi_angles: list[float],
    nlsq_results: list[NLSQResult] | None,
) -> list[NLSQResult] | None:
    """Quality-check, log and interior-clamp NLSQ warm-starts for CMC.

    Returns None when no NLSQ results are given (None or empty). Otherwise
    returns one processed result for each phi angle that has a positional
    counterpart in *nlsq_results*. A length mismatch is logged because the
    engine then receives warm-starts that do not line up 1:1 with the angles.
    """
    if not nlsq_results:
        return None
    if len(nlsq_results) != len(phi_angles):
        logger.warning(
            "CMC received %d NLSQ warm-start result(s) for %d phi angle(s); "
            "warm-starts are paired with phi angles by position",
            len(nlsq_results),
            len(phi_angles),
        )

    # Current values of the model's non-varying parameters, forwarded to the
    # clamp as fixed overrides (None when every parameter varies).
    varying = set(model.varying_names)
    fixed_overrides = {
        name: float(val)
        for name, val in model.get_params_dict().items()
        if name not in varying
    } or None

    warmstarts: list[NLSQResult] = []
    for phi, nr in zip(phi_angles, nlsq_results, strict=False):
        if _validate_warmstart_quality(nr):
            _log_warmstart_physical_params(nr)
        else:
            logger.warning(
                "Warm-start quality below threshold for phi=%s°; using anyway", phi
            )
        clamped = _clamp_warmstart_to_interior(
            nr, fixed_param_overrides=fixed_overrides
        )
        _warn_degenerate_sample_regime(clamped)
        warmstarts.append(clamped)
    return warmstarts


def _prune_cmc_outputs(saved_paths: dict[str, Any], output_format: str) -> None:
    """Delete saved CMC files that the requested output format excludes.

    ``"json"`` drops the ``samples`` file; ``"npz"`` drops the ``summary``
    and ``diagnostics`` files; any other value keeps everything.
    """
    if output_format == "json":
        unwanted: tuple[str, ...] = ("samples",)
    elif output_format == "npz":
        unwanted = ("summary", "diagnostics")
    else:
        unwanted = ()
    for key in unwanted:
        path = saved_paths.get(key)
        if path is not None:
            path.unlink(missing_ok=True)
def _is_degenerate_cmc_result(result: CMCResult) -> bool:
"""Return True when a CMC result is the all-shards-failed sentinel.
``_combine_shard_posteriors`` returns a CMCResult with no samples and
``convergence_passed=False`` when zero shards survived the
R-hat/ESS/no-samples gates. Detecting this lets the per-angle loop bail
out instead of running the next angle with the same broken warm-start.
"""
if getattr(result, "convergence_passed", True):
return False
samples = getattr(result, "samples", None)
if samples is None:
return True
# samples present but empty is also degenerate
try:
return all(np.asarray(s).size == 0 for s in samples.values())
except Exception:
return False
[docs]
def resolve_nlsq_warmstart(
    args: argparse.Namespace,
    output_dir: Path,
) -> NLSQResult | None:
    """Locate and load previously saved NLSQ results to warm-start CMC.

    Lookup order:

    1. ``args.nlsq_result`` — set by the documented ``--nlsq-result PATH``
       flag;
    2. ``args.warmstart_path`` — legacy attribute kept for programmatic
       callers;
    3. ``<output_dir>/nlsq_data.npz``, if that file exists.

    A path that names a directory is resolved to ``nlsq_data.npz`` inside it.

    Args:
        args: CLI arguments.
        output_dir: Default directory to search for NLSQ results.

    Returns:
        The loaded NLSQResult, or None when nothing is found or loading fails
        with OSError/ValueError/KeyError (logged as a warning).
    """
    # Honour both the CLI flag and the legacy attribute so the documented
    # flag is never silently ignored.
    candidate = getattr(args, "nlsq_result", None) or getattr(
        args, "warmstart_path", None
    )

    if candidate is None:
        fallback = output_dir / "nlsq_data.npz"
        logger.debug(
            "No warmstart path specified; checking default location %s", fallback
        )
        if not fallback.exists():
            logger.debug(
                "No NLSQ warm-start available; CMC will use config initial values"
            )
            return None
        candidate = fallback

    try:
        from heterodyne.io.nlsq_writers import load_nlsq_npz_file

        target = Path(candidate)
        if target.is_dir():
            target = target / "nlsq_data.npz"
        loaded = load_nlsq_npz_file(target)
        logger.info("Loaded NLSQ warm-start from %s", target)
        return loaded
    except (OSError, ValueError, KeyError) as exc:
        logger.warning(
            "Could not load NLSQ warm-start from %s: %s", candidate, exc
        )
        return None
def _get_warmstart_reduced_chi2(result: NLSQResult) -> float | None:
"""Extract reduced chi-squared from NLSQ result.
Tries ``result.reduced_chi_squared`` first, then falls back to
``result.metadata["reduced_chi_squared"]``.
Args:
result: NLSQ result to inspect.
Returns:
Reduced chi-squared value, or ``None`` if unavailable.
"""
chi2 = getattr(result, "reduced_chi_squared", None)
if chi2 is not None:
return float(chi2)
return result.metadata.get("reduced_chi_squared") if result.metadata else None
def _validate_warmstart_quality(
    result: NLSQResult,
    chi2_threshold: float = 10.0,
) -> bool:
    """Check whether an NLSQ result is suitable for warm-starting CMC.

    All checks run even after one fails, and each failure is logged:

    * convergence: ``result.success`` must be truthy (skipped if absent);
    * reduced chi-squared: must be finite and below *chi2_threshold*
      (skipped if unavailable);
    * bounds: every parameter known to the parameter registry must lie in
      ``[min_bound, max_bound]`` (the check is abandoned if the registry
      cannot be imported or raises AttributeError/KeyError).

    Args:
        result: NLSQ result to validate.
        chi2_threshold: Maximum acceptable reduced chi-squared (exclusive).

    Returns:
        ``True`` if quality is acceptable, ``False`` otherwise.
    """
    ok = True

    # --- convergence flag ---
    if hasattr(result, "success") and not result.success:
        logger.warning(
            "Warm-start NLSQ did not converge: %s", getattr(result, "message", "")
        )
        ok = False
    elif hasattr(result, "success"):
        logger.debug("Warm-start convergence: OK")

    # --- reduced chi-squared ---
    chi2 = _get_warmstart_reduced_chi2(result)
    if chi2 is not None:
        if not np.isfinite(chi2):
            # NaN compares False against any threshold, so without this
            # explicit check a NaN chi² would be accepted as a good fit.
            logger.warning("Warm-start reduced chi² = %s is not finite", chi2)
            ok = False
        elif chi2 >= chi2_threshold:
            logger.warning(
                "Warm-start reduced chi² = %.3f exceeds threshold %.1f",
                chi2,
                chi2_threshold,
            )
            ok = False
        else:
            logger.debug(
                "Warm-start reduced chi² = %.3f (threshold %.1f)", chi2, chi2_threshold
            )

    # --- parameter bounds check via registry ---
    try:
        from heterodyne.config.parameter_registry import ParameterRegistry

        registry = ParameterRegistry()
        params = result.params_dict
        for name, value in params.items():
            try:
                info = registry[name]
            except KeyError:
                # Parameter not in the registry: no bounds to check.
                continue
            if not (info.min_bound <= value <= info.max_bound):
                logger.warning(
                    "Warm-start param %s = %.4e outside bounds [%.4e, %.4e]",
                    name,
                    value,
                    info.min_bound,
                    info.max_bound,
                )
                ok = False
    except (ImportError, AttributeError, KeyError):
        # Registry unavailable — skip bounds check
        pass

    if ok:
        chi2_str = f"{chi2:.3f}" if chi2 is not None else "N/A"
        logger.info(
            "NLSQ warm-start accepted (reduced chi²=%s). Using as CMC initial values.",
            chi2_str,
        )
    return ok
def _warn_nlsq_bound_saturation(result: NLSQResult) -> None:
"""Log a WARNING for each parameter whose uncertainty is zero or near-zero.
Zero uncertainty in the NLSQ covariance has two causes:
1. Parameter hit a bound → Jacobian column is zeroed by the box constraint.
2. Fraction model clipped to [0,1] everywhere → Jacobian w.r.t. f1/f2 = 0.
Both indicate a degenerate solution that will produce a pathological CMC
posterior (NUTS step-size collapses, chains freeze at initialization).
"""
if result.uncertainties is None or result.parameter_names is None:
return
try:
from heterodyne.config.parameter_registry import DEFAULT_REGISTRY
registry: Any = DEFAULT_REGISTRY
except ImportError:
registry = None
params = result.params_dict
saturated: list[str] = []
for name, unc in zip(result.parameter_names, result.uncertainties, strict=True):
if float(unc) < 1e-30:
val = params.get(name, float("nan"))
hint = ""
if registry is not None:
try:
info = registry[name]
if abs(val - info.min_bound) < 1e-10 * max(abs(info.min_bound), 1):
hint = " [AT LOWER BOUND]"
elif abs(val - info.max_bound) < 1e-10 * max(
abs(info.max_bound), 1
):
hint = " [AT UPPER BOUND]"
else:
hint = " [DEGENERATE JACOBIAN — check clipping]"
except KeyError:
pass
logger.warning(
"NLSQ bound saturation: %s = %.4g ± 0%s — "
"posterior will be unreliable; CMC chains may freeze",
name,
val,
hint,
)
saturated.append(name)
if saturated:
logger.warning(
"%d parameter(s) saturated at bounds or degenerate: %s. "
"Consider tightening bounds, adjusting initial values, or fixing "
"these parameters before running CMC.",
len(saturated),
saturated,
)
_BOUNDARY_INTERIOR_MARGIN = 5e-2 # fraction of bound range to keep away from walls
def _clamp_warmstart_to_interior(
    result: NLSQResult,
    fixed_param_overrides: dict[str, float] | None = None,
) -> NLSQResult:
    """Forward an NLSQ warm-start to ``cmc.priors.clamp_to_interior``.

    The implementation formerly in :mod:`heterodyne.optimization.cmc.warmstart`
    was absorbed into :mod:`heterodyne.optimization.cmc.priors`
    (P2-a / Phase 4 PR 4 Task 4.4, spec §7 Rule 1). This private shim only
    keeps the CLI-internal name working for the rest of this module.

    Args:
        result: NLSQ result used as the warm-start.
        fixed_param_overrides: Optional fixed parameter values, passed
            through unchanged.

    Returns:
        Whatever ``clamp_to_interior`` returns for these arguments.
    """
    from heterodyne.optimization.cmc.priors import clamp_to_interior

    clamped = clamp_to_interior(result, fixed_param_overrides)
    return clamped
_WARMSTART_LOG_PARAMS = ("D0_ref", "D0_sample", "v0", "alpha_ref", "alpha_sample")
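# _F0_DEGEN_THRESHOLD and _ALPHA_SINGULARITY (used below) are imported at the
# top of this module from heterodyne.optimization.cmc.core, the authoritative
# source, so the two warning systems stay in sync when a threshold is tuned.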
def _warn_degenerate_sample_regime(result: NLSQResult) -> None:
"""Warn when the NLSQ solution is in a degenerate sample-transport regime.
Two conditions — individually or combined — cause 100% CMC shard
bad_convergence and BFMI=0.000 (the het_bb97531f failure mode):
1. f0 < ``_F0_DEGEN_THRESHOLD``: sample fraction near zero makes
alpha_sample, D0_sample, D_offset_sample unidentifiable. Per-shard
posteriors are dominated by the tempered prior; NUTS must thermalize
from the warm-start across the full prior range with near-zero
likelihood gradient for the sample-transport group.
2. alpha_sample < ``_ALPHA_SINGULARITY``: J_sample(t) ∝ t^α has a
non-integrable singularity at t→0. NUTS step-size adaptation
collapses for the sample-transport group.
Calling this before CMC dispatch surfaces the problem immediately,
giving the user time to act (freeze parameters, increase warmup) before
investing hours of compute.
"""
params = result.params_dict
f0 = params.get("f0")
alpha_s = params.get("alpha_sample")
f0_degen = f0 is not None and float(f0) < _F0_DEGEN_THRESHOLD
alpha_sing = alpha_s is not None and float(alpha_s) < _ALPHA_SINGULARITY
if not f0_degen and not alpha_sing:
return
parts: list[str] = []
if f0_degen:
parts.append(
f"f0={float(f0):.4f} < {_F0_DEGEN_THRESHOLD} — sample fraction " # type: ignore[arg-type]
"near-zero; alpha_sample / D0_sample / D_offset_sample are "
"unidentifiable from data"
)
if alpha_sing:
parts.append(
f"alpha_sample={float(alpha_s):.3f} < {_ALPHA_SINGULARITY} — " # type: ignore[arg-type]
"J_sample ∝ t^α singularity at short lags; NUTS step-size collapses"
)
logger.warning(
"Degenerate NLSQ warm-start (het_bb97531f failure mode) — CMC likely "
"to fail for ALL shards:\n %s\n"
"Fixes:\n"
" • Freeze degenerate params in YAML → optimization.cmc.fixed_params: "
"{alpha_sample: %.3f, D0_sample: %.3g}\n"
" • Or increase num_warmup to ≥2000\n"
" • If f0 < 0.05, consider fixing f0=0 (reference-only model)",
";\n ".join(parts),
float(alpha_s) if alpha_s is not None else float("nan"),
float(params.get("D0_sample", float("nan"))),
)
def _log_warmstart_physical_params(result: NLSQResult) -> None:
    """Log the key physical parameters of an NLSQ warm-start at INFO level.

    Parameters named in ``_WARMSTART_LOG_PARAMS`` are printed in that order
    using two-decimal scientific notation. Names absent from
    ``result.params_dict`` are skipped, and nothing is logged if none are
    present.

    Args:
        result: NLSQ result whose parameters are logged.
    """
    values = result.params_dict
    rendered = [
        f"{key}={values[key]:.2e}" for key in _WARMSTART_LOG_PARAMS if key in values
    ]
    if not rendered:
        return
    logger.info("Warm-start params: %s", ", ".join(rendered))