"""Reference-time reparameterization for heterodyne CMC.
Breaks banana-shaped posteriors for correlated power-law pairs
(D0/alpha) by sampling at a reference time t_ref where the product
D(t_ref) = D0 * t_ref^alpha is well-constrained by data.
Adapted from homodyne/optimization/cmc/reparameterization.py for
heterodyne's 3 power-law pairs --
D0_ref/alpha_ref, D0_sample/alpha_sample, and v0/beta.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import jax.numpy as jnp
import numpy as np
from heterodyne.utils.logging import get_logger
logger = get_logger(__name__)
# Power-law pairs: (prefactor_name, exponent_name).
# The order is significant: ReparamConfig.enabled_pairs zips this tuple
# positionally against its enable_d_ref / enable_d_sample / enable_v_ref
# flags, so entries must stay aligned with that flag order.
POWER_LAW_PAIRS: tuple[tuple[str, str], ...] = (
    ("D0_ref", "alpha_ref"),
    ("D0_sample", "alpha_sample"),
    ("v0", "beta"),
)
@dataclass(frozen=True)
class ReparamConfig:
    """Immutable settings for reference-time reparameterization.

    Each boolean flag switches one power-law pair on or off; the flags are
    matched positionally against the module-level ``POWER_LAW_PAIRS``.

    Attributes:
        enable_d_ref: Reparameterize the D0_ref/alpha_ref pair.
        enable_d_sample: Reparameterize the D0_sample/alpha_sample pair.
        enable_v_ref: Reparameterize the v0/beta pair.
        t_ref: Reference time (geometric mean of dt and t_max).
    """

    enable_d_ref: bool = True
    enable_d_sample: bool = True
    enable_v_ref: bool = True
    t_ref: float = 1.0

    @property
    def enabled_pairs(self) -> list[tuple[str, str]]:
        """Return the (prefactor, exponent) pairs whose flag is switched on."""
        switches = (self.enable_d_ref, self.enable_d_sample, self.enable_v_ref)
        active: list[tuple[str, str]] = []
        # strict=True turns a length mismatch between the flags and
        # POWER_LAW_PAIRS into an error instead of silent truncation.
        for pair, switched_on in zip(POWER_LAW_PAIRS, switches, strict=True):
            if switched_on:
                active.append(pair)
        return active

    def is_reparameterized(self, name: str) -> bool:
        """Tell whether ``name`` is a member of any enabled pair.

        Both the prefactor and the exponent of an enabled pair count.
        """
        return any(name in pair for pair in self.enabled_pairs)

    def get_reparam_name(self, prefactor: str) -> str:
        """Build the log-space sampling name for ``prefactor``.

        The result has the form ``log_<prefactor>_at_tref``.
        """
        return "log_{}_at_tref".format(prefactor)
def compute_t_ref(
    dt: float,
    t_max: float,
    fallback_value: float | None = None,
) -> float:
    """Return the geometric mean of ``dt`` and ``t_max`` as reference time.

    t_ref = sqrt(dt * t_max)

    The geometric mean sits in the middle of the logarithmic time range,
    where the correlation function is most sensitive to the transport
    parameters.

    Args:
        dt: Time step (minimum lag time).
        t_max: Maximum lag time.
        fallback_value: Returned instead of raising when the inputs or the
            computed reference time are unusable. ``None`` disables the
            fallback.

    Returns:
        The reference time, or ``fallback_value`` when the computation is
        invalid and a fallback was supplied.

    Raises:
        ValueError: If ``dt`` or ``t_max`` is non-positive, or the result is
            non-positive or non-finite, and no fallback was supplied.
    """
    use_fallback = fallback_value is not None

    # Guard clause: both time scales must be positive.
    if dt <= 0 or t_max <= 0:
        if not use_fallback:
            raise ValueError(f"dt and t_max must be positive, got dt={dt}, t_max={t_max}")
        return fallback_value

    t_ref = math.sqrt(dt * t_max)

    # Happy path: a finite, strictly positive reference time.
    if math.isfinite(t_ref) and t_ref > 0:
        return t_ref

    # NaN/inf inputs or an under-/overflowing product end up here.
    if not use_fallback:
        raise ValueError(f"Invalid t_ref={t_ref} from dt={dt}, t_max={t_max}")
    return fallback_value
def reparam_to_physics_jax(
    log_at_tref: jnp.ndarray,
    alpha: jnp.ndarray,
    t_ref: float,
) -> jnp.ndarray:
    """Recover the physics-space prefactor from its reparameterized form (JAX).

    A0 = exp(log_at_tref - alpha * log(t_ref))

    Args:
        log_at_tref: Log of the quantity evaluated at ``t_ref``.
        alpha: Power-law exponent.
        t_ref: Reference time.

    Returns:
        A0, the prefactor in physics space.
    """
    log_t_ref = jnp.log(jnp.float64(t_ref))
    # Undo the shift to t_ref in log space, then leave log space.
    log_prefactor = log_at_tref - alpha * log_t_ref
    return jnp.exp(log_prefactor)
# ---------------------------------------------------------------------------
# D_offset ratio reparameterization — homodyne CMC parity
# ---------------------------------------------------------------------------
#
# Homodyne's reparameterized model (model.py:863-1106) samples
# ``D_offset_ratio = D_offset / D_ref`` instead of ``D_offset`` directly:
#
# * It conditions the gradient with respect to ``D_offset`` on the
# active diffusion magnitude ``D_ref``, which is helpful when the
# offset is small relative to the prefactor.
# * It naturally supports negative offsets via a TruncatedNormal with
# ``low = -1 + ε`` so ``D_offset > -D_ref`` (i.e. ``J(t) > 0``).
#
# Heterodyne has *two* offsets (``D_offset_ref`` and ``D_offset_sample``),
# so there are two ratios. The helpers below provide the conversion in
# both directions so callers (and a future reparam path) can move between
# raw and ratio representations.
#: Minimum allowed ratio. Slightly above ``-1`` so the implied
#: ``D_offset > -D_ref`` keeps the diffusion rate strictly positive at
#: t_ref while preserving gradient information near the boundary.
#: NOTE(review): the conversion helpers visible in this section neither
#: read nor enforce this bound -- presumably callers apply it; confirm.
D_OFFSET_RATIO_MIN: float = -0.99
def d_offset_to_ratio(d_offset: float, d_ref: float) -> float:
    """Express an absolute offset as a fraction of ``d_ref``.

    ``d_offset_ratio = d_offset / d_ref``. A non-positive ``d_ref``
    (degenerate channel) yields ``0.0`` instead of dividing.
    """
    # Keep the ``<=`` comparison: a NaN d_ref fails it and so propagates
    # through the division rather than being mapped to 0.0.
    return 0.0 if d_ref <= 0.0 else float(d_offset / d_ref)
def ratio_to_d_offset(ratio: float, d_ref: float) -> float:
    """Turn a ratio back into the absolute offset it represents.

    ``d_offset = ratio * d_ref``. A non-positive ``d_ref`` yields ``0.0``.
    """
    # Keep the ``<=`` comparison: a NaN d_ref fails it and so propagates
    # through the product rather than being mapped to 0.0.
    return 0.0 if d_ref <= 0.0 else float(ratio * d_ref)
def heterodyne_offset_ratios_from_physics(
    params: dict[str, float],
    t_ref: float,
) -> dict[str, float]:
    """Derive ``D_offset_*_ratio`` for the reference and sample channels.

    For each channel the diffusion magnitude at ``t_ref`` is taken as
    ``D_ref(t_ref) = D0 * t_ref**alpha`` and the ratio returned is
    ``D_offset_* / D_ref(t_ref)``. A channel whose ``D_ref(t_ref)`` is
    non-positive gets a ``0.0`` ratio so callers can fall back to direct
    sampling for that channel.

    Args:
        params: Mapping containing ``D0_ref``, ``alpha_ref``,
            ``D_offset_ref``, ``D0_sample``, ``alpha_sample``,
            ``D_offset_sample``. Any missing key is read as ``0.0``.
        t_ref: Reference time at which to evaluate ``D_ref``. When it is
            not positive, ``D0`` itself is used as the magnitude.

    Returns:
        ``{"D_offset_ref_ratio": float, "D_offset_sample_ratio": float}``.
    """
    ratios: dict[str, float] = {}
    for channel in ("ref", "sample"):
        prefactor = float(params.get(f"D0_{channel}", 0.0))
        exponent = float(params.get(f"alpha_{channel}", 0.0))
        offset = float(params.get(f"D_offset_{channel}", 0.0))

        # The power law is only evaluated for a positive reference time.
        if t_ref > 0:
            magnitude = prefactor * (t_ref**exponent)
        else:
            magnitude = prefactor

        ratios[f"D_offset_{channel}_ratio"] = d_offset_to_ratio(offset, magnitude)
    return ratios
def heterodyne_physics_offsets_from_ratios(
    ratios: dict[str, float],
    physics: dict[str, float],
    t_ref: float,
) -> dict[str, float]:
    """Inverse of :func:`heterodyne_offset_ratios_from_physics`.

    Combines the ``D_offset_*_ratio`` values with the current
    physics-space ``(D0_*, alpha_*)`` parameters to rebuild the absolute
    ``D_offset_*`` values consistent with ``D_ref(t_ref)``.

    Args:
        ratios: Mapping with ``D_offset_ref_ratio`` and
            ``D_offset_sample_ratio``. Any missing key is read as ``0.0``.
        physics: Mapping with ``D0_*`` and ``alpha_*`` for both channels.
            Any missing key is read as ``0.0``.
        t_ref: Reference time at which ``D_ref`` is evaluated. When it is
            not positive, ``D0`` itself is used as the magnitude.

    Returns:
        ``{"D_offset_ref": float, "D_offset_sample": float}``.
    """
    offsets: dict[str, float] = {}
    for channel in ("ref", "sample"):
        prefactor = float(physics.get(f"D0_{channel}", 0.0))
        exponent = float(physics.get(f"alpha_{channel}", 0.0))
        fraction = float(ratios.get(f"D_offset_{channel}_ratio", 0.0))

        # The power law is only evaluated for a positive reference time.
        if t_ref > 0:
            magnitude = prefactor * (t_ref**exponent)
        else:
            magnitude = prefactor

        offsets[f"D_offset_{channel}"] = ratio_to_d_offset(fraction, magnitude)
    return offsets