Source code for heterodyne.optimization.cmc.scaling

"""Smooth bounded parameter scaling for heterodyne CMC.

Replaces jnp.clip() (zero gradient at bounds) with tanh-based smooth
bounding that is differentiable everywhere, allowing NUTS to adapt its
mass matrix near parameter boundaries.

Adapted from homodyne/optimization/cmc/scaling.py.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import jax.numpy as jnp
import numpy as np

from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from heterodyne.config.parameter_space import ParameterSpace

logger = get_logger(__name__)


@dataclass
class ParameterScaling:
    """Per-parameter mapping between sampler z-space and physics space.

    The sampler works on a standardized variable ``z``; physics values are
    obtained by an affine map followed by tanh-based smooth bounding into
    ``(low, high)``.

    Attributes:
        name: Parameter name.
        center: NLSQ best-fit value (center of prior).
        scale: Prior width (NLSQ uncertainty × width_factor).
        low: Lower bound in physics space.
        high: Upper bound in physics space.
    """

    name: str
    center: float
    scale: float
    low: float
    high: float

    def to_normalized(self, value: float | jnp.ndarray) -> float | jnp.ndarray:
        """Map a physics-space value onto z-space with the affine part only.

        Computes ``(value - center) / scale``; no bounding is undone here.

        Args:
            value: Physics-space value.

        Returns:
            Normalized z-space value.
        """
        offset = value - self.center
        return offset / self.scale

    def to_original(self, z_value: jnp.ndarray) -> jnp.ndarray:
        """Map a z-space value back to bounded physics space.

        First forms the unbounded value ``center + scale * z_value``, then
        squeezes it into ``(low, high)`` with ``smooth_bound``.

        Args:
            z_value: Normalized z-space value.

        Returns:
            Bounded physics-space value.
        """
        unbounded = self.center + self.scale * z_value
        return smooth_bound(unbounded, self.low, self.high)
def smooth_bound(
    raw: jnp.ndarray,
    low: float,
    high: float,
) -> jnp.ndarray:
    """Smooth bounding using tanh transform.

    Maps (-inf, +inf) → (low, high) via:

        mid + half * tanh((raw - mid) / half)

    This is differentiable everywhere, unlike jnp.clip() which has zero
    gradient at bounds and kills NUTS adaptation.

    Degenerate bounds (``high <= low``) return ``mid``, and the gradient
    with respect to *raw* is zero rather than NaN.

    Args:
        raw: Unbounded input value.
        low: Lower bound.
        high: Upper bound.

    Returns:
        Bounded value in (low, high), broadcast against *raw*; ``mid``
        when the bounds are degenerate.
    """
    mid = jnp.float64((low + high) / 2.0)
    half = jnp.float64((high - low) / 2.0)
    valid = half > 0
    # jnp.where only masks the forward value. On the backward pass the
    # unselected branch still receives a zero cotangent, which is multiplied
    # by that branch's local derivative. With half == 0 that derivative
    # involves a division by zero, so 0 * inf/NaN = NaN and NUTS breaks.
    # Substitute a harmless divisor first (the "double where" pattern) so
    # both branches stay finite.
    safe_half = jnp.where(valid, half, 1.0)
    bounded = mid + safe_half * jnp.tanh((raw - mid) / safe_half)
    return jnp.where(valid, bounded, mid)
def smooth_bound_inverse(
    value: float,
    low: float,
    high: float,
) -> float:
    """Inverse of smooth_bound for initialization.

    Recovers the raw (unbounded) value from a bounded value:

        raw = mid + half * arctanh((value - mid) / half)

    The normalized argument is clamped to [-0.999, 0.999], so values at or
    beyond a bound map to a large but finite raw value instead of ±inf.
    For degenerate bounds (``high <= low``), ``smooth_bound`` returns
    ``mid`` for every input. ``mid`` is therefore returned here as a valid
    preimage instead of dividing by zero.

    Args:
        value: Bounded value in (low, high).
        low: Lower bound.
        high: Upper bound.

    Returns:
        Unbounded raw value.
    """
    mid = (low + high) / 2.0
    half = (high - low) / 2.0
    if not half > 0:
        # Mirrors smooth_bound's degenerate-bounds guard. Without it the
        # division below raises ZeroDivisionError for low == high.
        return mid
    # Clamp to avoid arctanh(±1) = ±inf
    normalized = float(np.clip((value - mid) / half, -0.999, 0.999))
    return mid + half * float(np.arctanh(normalized))
def compute_scaling_factors(
    space: ParameterSpace,
    nlsq_values: dict[str, float] | None = None,
    nlsq_uncertainties: dict[str, float] | None = None,
    width_factor: float = 2.0,
) -> dict[str, ParameterScaling]:
    """Build ParameterScaling for each varying parameter.

    Uses NLSQ values as centers and NLSQ uncertainties × width_factor as
    scale. Falls back to the bounds midpoint and range/6 when an NLSQ
    result is missing or unusable:

    * a non-finite NLSQ value;
    * a non-finite or non-positive uncertainty. For example, a fitter may
      report an infinite uncertainty for a poorly determined parameter.

    Args:
        space: Parameter space with bounds and varying flags.
        nlsq_values: NLSQ best-fit values by name.
        nlsq_uncertainties: NLSQ uncertainties by name.
        width_factor: Multiplier on NLSQ uncertainty for prior width.

    Returns:
        Dict mapping parameter name to ParameterScaling.
    """
    values = nlsq_values if nlsq_values is not None else {}
    uncertainties = nlsq_uncertainties if nlsq_uncertainties is not None else {}

    scalings: dict[str, ParameterScaling] = {}
    for name in space.varying_names:
        low, high = space.bounds[name]

        # Center: NLSQ value when it is a finite number, else bounds midpoint.
        nlsq_center = values.get(name)
        if nlsq_center is not None and np.isfinite(nlsq_center):
            center = float(nlsq_center)
        else:
            center = (low + high) / 2.0

        # Scale: NLSQ uncertainty × width_factor when finite and positive,
        # else bounds range / 6. An infinite uncertainty would otherwise
        # give scale = inf, so center + scale * z is inf or NaN
        # (inf * 0 = NaN) for every z.
        sigma = uncertainties.get(name)
        if sigma is not None and np.isfinite(sigma) and sigma > 0:
            scale = float(sigma) * width_factor
        else:
            scale = (high - low) / 6.0

        # Ensure minimum scale to avoid division by zero
        scale = max(scale, 1e-10)

        scalings[name] = ParameterScaling(
            name=name,
            center=center,
            scale=scale,
            low=low,
            high=high,
        )

    return scalings
def log_scaling_factors(scalings: dict[str, ParameterScaling]) -> None:
    """Report the scaling table through the module logger.

    A one-line summary (parameter count) goes out at INFO. Each parameter's
    center, scale and bounds follow at DEBUG, formatted in scientific
    notation with four decimals.

    Args:
        scalings: Mapping of parameter name to its scaling specification.
    """
    count = len(scalings)
    logger.info("Scaling factors for %d parameters:", count)

    sci = "{:.4e}".format
    for param_name, spec in scalings.items():
        logger.debug(
            " %s: center=%s, scale=%s, bounds=[%s, %s]",
            param_name,
            sci(spec.center),
            sci(spec.scale),
            sci(spec.low),
            sci(spec.high),
        )
def transform_initial_values_to_z(
    initial_values: dict[str, float],
    scalings: dict[str, ParameterScaling],
) -> dict[str, float]:
    """Transform initial values from physics space to z-space.

    Samples are mapped back to physics space with
    ``ParameterScaling.to_original``, i.e.
    ``smooth_bound(center + scale * z, low, high)``. A purely linear
    normalization would therefore start the sampler at
    ``smooth_bound(value)`` rather than at ``value``. To make the round trip
    reproduce the requested value, the tanh bounding is undone first with
    ``smooth_bound_inverse`` and the result is then normalized. Values
    closer to a bound than that function's clamp allows are pulled slightly
    inward.

    Only transforms parameters present in both *initial_values* and
    *scalings*.

    Args:
        initial_values: Physics-space values keyed by parameter name.
        scalings: Scaling specifications keyed by parameter name.

    Returns:
        Dict with keys ``{name}_z`` mapped to normalized z-space values.
    """
    z_values: dict[str, float] = {}
    for name, value in initial_values.items():
        spec = scalings.get(name)
        if spec is None:
            continue
        if spec.high > spec.low:
            raw = smooth_bound_inverse(value, spec.low, spec.high)
        else:
            # Degenerate bounds: smooth_bound collapses every raw value to
            # the midpoint, so there is no bound to undo. Normalize directly.
            raw = value
        z_values[f"{name}_z"] = float(spec.to_normalized(raw))
    return z_values
def transform_samples_from_z(
    samples: dict[str, jnp.ndarray],
    scalings: dict[str, ParameterScaling],
) -> dict[str, jnp.ndarray]:
    """Convert z-space MCMC draws back into physics-space draws.

    Only entries whose key ends in ``"_z"`` are considered. The two-character
    suffix is removed to get the parameter name. Entries without a scaling
    for that name are dropped. Output order follows the input order.

    Args:
        samples: Z-space sample arrays keyed by ``{name}_z``.
        scalings: Scaling specifications keyed by parameter name.

    Returns:
        Dict with original parameter names mapped to physics-space arrays.
    """
    return {
        param: scalings[param].to_original(z_draws)
        for sample_key, z_draws in samples.items()
        if sample_key.endswith("_z") and (param := sample_key[:-2]) in scalings
    }