# Source code for heterodyne.optimization.cmc.model

"""NumPyro model definition for heterodyne Bayesian inference."""

from __future__ import annotations

import math
from typing import TYPE_CHECKING

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

from heterodyne.config.parameter_names import ALL_PARAM_NAMES, PARAM_INDICES
from heterodyne.config.parameter_registry import DEFAULT_REGISTRY
from heterodyne.core.jax_backend import (
    compute_c2_heterodyne,
    compute_c2_heterodyne_pooled,
)
from heterodyne.core.physics_cmc import ShardGrid, compute_c2_elementwise
from heterodyne.optimization.cmc.reparameterization import (
    ReparamConfig,
    reparam_to_physics_jax,
)
from heterodyne.optimization.cmc.scaling import ParameterScaling, smooth_bound
from heterodyne.utils.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import Callable

    from heterodyne.config.parameter_space import ParameterSpace
    from heterodyne.optimization.nlsq.results import NLSQResult

logger = get_logger(__name__)


def _heterodyne_sample_shared_physics(space: ParameterSpace) -> jnp.ndarray:
    """Draw the shared (angle-independent) physics parameters.

    Starts from ``space.get_initial_array()`` and overwrites, in
    ``ALL_PARAM_NAMES`` order, every entry whose name is in
    ``space.varying_names`` with a ``numpyro.sample`` draw from that
    parameter's prior. Non-varying entries keep their initial value.

    ``contrast`` and ``offset`` are never sampled here: the joint models
    supply them separately as per-angle ``contrast_arr`` / ``offset_arr``.

    Returns:
        The full physics parameter vector (the original documentation
        describes it as shape ``(14,)``).
    """
    per_angle_scaling = ("contrast", "offset")
    sampled_names = space.varying_names
    theta = jnp.asarray(space.get_initial_array())

    for idx, pname in enumerate(ALL_PARAM_NAMES):
        # Scaling parameters are handled per angle by the caller.
        if pname in per_angle_scaling or pname not in sampled_names:
            continue
        draw = numpyro.sample(pname, space.priors[pname].to_numpyro(pname))
        # Functional update: keeps the array traceable under JAX.
        theta = theta.at[idx].set(draw)

    return theta


def _heterodyne_pooled_likelihood(
    params: jnp.ndarray,
    contrast_arr: jnp.ndarray,
    offset_arr: jnp.ndarray,
    data: jnp.ndarray,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    i1_indices: jnp.ndarray,
    i2_indices: jnp.ndarray,
    noise_scale: float,
    num_shards: int,
) -> None:
    """Shared tail of every joint multi-phi model: predict, sample sigma, score.

    Evaluates :func:`compute_c2_heterodyne_pooled` directly at the pooled
    ``(phi, t1, t2)`` points, so no ``(n_phi, N, N)`` stack is ever built.
    The steps are:

    1. Record the number of non-finite predictions as the deterministic site
       ``"n_numerical_issues"``.
    2. Sample ``sigma ~ HalfNormal(noise_scale * 1.5 * sqrt(num_shards))``.
       The ``sqrt(num_shards)`` factor tempers the prior for CMC sharding.
    3. Score ``data`` under ``Normal(c2, sigma)``. Points with
       ``i1_indices == 0`` or ``i2_indices == 0`` (the t=0 row/column) are
       masked out of the likelihood.

    Used by the scaled, constant, averaged and constant_averaged joint
    variants.
    """
    c2_pred = compute_c2_heterodyne_pooled(
        params,
        t,
        q,
        dt,
        i1_indices,
        i2_indices,
        phi_indices,
        phi_unique,
        contrast_arr,
        offset_arr,
    )

    numpyro.deterministic("n_numerical_issues", jnp.sum(~jnp.isfinite(c2_pred)))

    tempered_scale = float(noise_scale) * 1.5 * math.sqrt(num_shards)
    sigma = numpyro.sample("sigma", dist.HalfNormal(scale=tempered_scale))

    # Exclude the t=0 boundary from the fit (it is still loaded and plotted).
    keep = (i1_indices > 0) & (i2_indices > 0)
    with numpyro.handlers.mask(mask=keep):
        numpyro.sample("obs", dist.Normal(c2_pred, sigma), obs=data)


def xpcs_model_heterodyne_scaled(
    data: jnp.ndarray,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    i1_indices: jnp.ndarray,
    i2_indices: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    num_shards: int = 1,
) -> None:
    """Joint multi-phi heterodyne CMC model with per-angle SAMPLED scaling.

    Counterpart of ``homodyne.optimization.cmc.model.xpcs_model_scaled``.
    One NUTS pass covers the pooled multi-phi data. The physics parameters
    are shared across angles; each angle ``k`` gets its own sampled
    ``contrast_k`` and ``offset_k``. The likelihood is one
    ``numpyro.sample`` over all pooled points, gathered by phi index.

    Parameters
    ----------
    data:
        Pooled C2 values, shape ``(n_total,)`` after diagonal filtering.
    t:
        Unique time grid, shape ``(N,)``.
    q, dt:
        Physics scalars.
    phi_unique:
        Sorted unique phi angles, shape ``(n_phi,)``.
    phi_indices:
        Per-point index into ``phi_unique``, shape ``(n_total,)``.
    i1_indices, i2_indices:
        Per-point indices into ``t`` for the two time coordinates, each of
        shape ``(n_total,)``. They are expected to come from
        ``np.searchsorted(t, t1)`` / ``np.searchsorted(t, t2)``.
    noise_scale:
        Data-driven sigma prior centre. The ``HalfNormal`` scale is
        ``noise_scale * 1.5 * sqrt(num_shards)``.
    space:
        Parameter space holding priors and initial values for the physics
        and scaling parameters.
    num_shards:
        Shard count for CMC sigma-prior tempering (Scott et al. 2016).
        The default ``1`` means no tempering.
    """
    n_phi = int(phi_unique.shape[0])
    contrast_prior = space.priors["contrast"]
    offset_prior = space.priors["offset"]

    # All contrast sites are drawn before any offset site, and both before
    # the physics sites, so the trace keeps its site order.
    contrast_draws = []
    for k in range(n_phi):
        site = f"contrast_{k}"
        contrast_draws.append(numpyro.sample(site, contrast_prior.to_numpyro(site)))

    offset_draws = []
    for k in range(n_phi):
        site = f"offset_{k}"
        offset_draws.append(numpyro.sample(site, offset_prior.to_numpyro(site)))

    contrast_arr = jnp.stack(contrast_draws)
    offset_arr = jnp.stack(offset_draws)
    physics = _heterodyne_sample_shared_physics(space)

    _heterodyne_pooled_likelihood(
        physics,
        contrast_arr,
        offset_arr,
        data,
        t,
        q,
        dt,
        phi_unique,
        phi_indices,
        i1_indices,
        i2_indices,
        noise_scale,
        num_shards,
    )
def xpcs_model_heterodyne_constant(
    data: jnp.ndarray,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    i1_indices: jnp.ndarray,
    i2_indices: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    fixed_contrast: jnp.ndarray,
    fixed_offset: jnp.ndarray,
    num_shards: int = 1,
) -> None:
    """Joint multi-phi CMC model with FIXED per-angle scaling.

    Mirrors ``homodyne.optimization.cmc.model.xpcs_model_constant``.
    Per-angle ``contrast`` and ``offset`` are passed in, not sampled. They
    are typically derived from quantile estimation on the raw data. Only
    the shared physics parameters and ``sigma`` are sampled.

    Args:
        fixed_contrast, fixed_offset: Per-angle values with one entry per
            element of ``phi_unique``. A scalar or a single-element input is
            broadcast to every angle.
        (All other arguments are as in ``xpcs_model_heterodyne_scaled``.)

    Raises:
        ValueError: if ``fixed_contrast`` or ``fixed_offset`` has a length
            other than 1 or ``n_phi``. Without this check a short array
            would reach the per-phi gather, where JAX clamps out-of-bounds
            indices, and the wrong scaling would be used silently.
    """
    n_phi = int(phi_unique.shape[0])

    def _per_angle(values, label: str) -> jnp.ndarray:
        # Flatten so that 0-d scalars and (1,) inputs are handled uniformly.
        arr = jnp.ravel(jnp.asarray(values, dtype=jnp.float64))
        if arr.shape[0] == n_phi:
            return arr
        if arr.shape[0] == 1:
            return jnp.full((n_phi,), arr[0])
        raise ValueError(
            f"{label} must have length n_phi={n_phi} (or be a scalar); "
            f"got {arr.shape[0]} values."
        )

    # Validate both arrays before any numpyro site is created.
    contrast_arr = _per_angle(fixed_contrast, "fixed_contrast")
    offset_arr = _per_angle(fixed_offset, "fixed_offset")
    params = _heterodyne_sample_shared_physics(space)
    _heterodyne_pooled_likelihood(
        params,
        contrast_arr,
        offset_arr,
        data,
        t,
        q,
        dt,
        phi_unique,
        phi_indices,
        i1_indices,
        i2_indices,
        noise_scale,
        num_shards,
    )
def xpcs_model_heterodyne_averaged(
    data: jnp.ndarray,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    i1_indices: jnp.ndarray,
    i2_indices: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    num_shards: int = 1,
) -> None:
    """Joint multi-phi CMC model with one SAMPLED scaling shared by all angles.

    Counterpart of ``homodyne.optimization.cmc.model.xpcs_model_averaged``.
    Exactly one ``contrast`` site and one ``offset`` site are sampled. Their
    values are then repeated for each of the ``n_phi`` angles. This
    corresponds to heterodyne ``per_angle_mode="auto"`` when the
    auto-resolver promotes to averaged scaling.
    """
    n_phi = int(phi_unique.shape[0])

    contrast = numpyro.sample(
        "contrast", space.priors["contrast"].to_numpyro("contrast")
    )
    offset = numpyro.sample("offset", space.priors["offset"].to_numpyro("offset"))

    # Broadcast the single draw to a per-angle vector for the pooled gather.
    per_angle_shape = (n_phi,)
    contrast_arr = jnp.full(per_angle_shape, contrast)
    offset_arr = jnp.full(per_angle_shape, offset)

    physics = _heterodyne_sample_shared_physics(space)
    _heterodyne_pooled_likelihood(
        physics,
        contrast_arr,
        offset_arr,
        data,
        t,
        q,
        dt,
        phi_unique,
        phi_indices,
        i1_indices,
        i2_indices,
        noise_scale,
        num_shards,
    )
def xpcs_model_heterodyne_constant_averaged(
    data: jnp.ndarray,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    i1_indices: jnp.ndarray,
    i2_indices: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    fixed_contrast: float,
    fixed_offset: float,
    num_shards: int = 1,
) -> None:
    """Joint multi-phi CMC model with one FIXED scaling shared by all angles.

    Counterpart of
    ``homodyne.optimization.cmc.model.xpcs_model_constant_averaged``. The
    scalar ``fixed_contrast`` / ``fixed_offset`` are typically the mean of
    the NLSQ per-angle estimates. They are converted with ``float`` and
    repeated for every angle. No scaling site is sampled; only the shared
    physics parameters and ``sigma`` are.
    """
    per_angle_shape = (int(phi_unique.shape[0]),)
    contrast_arr = jnp.full(per_angle_shape, float(fixed_contrast))
    offset_arr = jnp.full(per_angle_shape, float(fixed_offset))

    physics = _heterodyne_sample_shared_physics(space)
    _heterodyne_pooled_likelihood(
        physics,
        contrast_arr,
        offset_arr,
        data,
        t,
        q,
        dt,
        phi_unique,
        phi_indices,
        i1_indices,
        i2_indices,
        noise_scale,
        num_shards,
    )
def get_heterodyne_pooled_model_for_mode(
    per_angle_mode: str,
    *,
    data: jnp.ndarray,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_unique: jnp.ndarray,
    phi_indices: jnp.ndarray,
    i1_indices: jnp.ndarray,
    i2_indices: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    fixed_contrast: np.ndarray | jnp.ndarray | float | None = None,
    fixed_offset: np.ndarray | jnp.ndarray | float | None = None,
    num_shards: int = 1,
) -> Callable[[], None]:
    """Dispatch to the joint multi-phi model variant for ``per_angle_mode``.

    Mirrors ``homodyne.optimization.cmc.model.get_xpcs_model`` at the
    pooled-data layer. Returns a zero-argument callable that can be passed
    to ``NUTS``.

    Modes:
      - ``"individual"`` / ``"scaled"`` → :func:`xpcs_model_heterodyne_scaled`
        (per-angle sampled contrast/offset).
      - ``"constant"`` → :func:`xpcs_model_heterodyne_constant`. Requires
        ``fixed_contrast`` and ``fixed_offset``, each of length ``n_phi``.
        Scalar or single-element inputs are broadcast to all angles.
      - ``"auto"`` / ``"averaged"`` → :func:`xpcs_model_heterodyne_averaged`
        (single sampled averaged contrast/offset).
      - ``"constant_averaged"`` →
        :func:`xpcs_model_heterodyne_constant_averaged`. Requires
        ``fixed_contrast`` and ``fixed_offset``. Array inputs are reduced
        to their mean.

    Raises:
        ValueError: for an unknown mode, missing fixed scaling in a
            ``constant*`` mode, a per-angle array whose length is neither 1
            nor ``n_phi``, or an empty array in ``"constant_averaged"``.
            These errors are raised here, at construction time, not while
            NUTS traces the model.
    """
    # Arguments shared by every variant.
    common = {
        "data": data,
        "t": t,
        "q": q,
        "dt": dt,
        "phi_unique": phi_unique,
        "phi_indices": phi_indices,
        "i1_indices": i1_indices,
        "i2_indices": i2_indices,
        "noise_scale": noise_scale,
        "space": space,
        "num_shards": num_shards,
    }

    def _per_angle(values, label: str) -> jnp.ndarray:
        # Coerce to a (n_phi,) float vector. A short array would otherwise be
        # clamped silently by the downstream JAX gather on phi_indices.
        n_phi = int(phi_unique.shape[0])
        arr = jnp.ravel(jnp.asarray(values, dtype=jnp.float64))
        if arr.shape[0] == n_phi:
            return arr
        if arr.shape[0] == 1:
            return jnp.full((n_phi,), arr[0])
        raise ValueError(
            f"per_angle_mode='constant' requires {label} of length "
            f"n_phi={n_phi} (or a scalar); got {arr.shape[0]} values."
        )

    def _averaged(values, label: str) -> float:
        # Reduce to one scalar. Guard the empty case, where mean() is NaN.
        arr = np.asarray(values, dtype=float)
        if arr.size == 0:
            raise ValueError(
                f"per_angle_mode='constant_averaged' requires a non-empty {label}."
            )
        return float(arr.mean())

    if per_angle_mode in ("individual", "scaled"):
        return lambda: xpcs_model_heterodyne_scaled(**common)

    if per_angle_mode == "constant":
        if fixed_contrast is None or fixed_offset is None:
            raise ValueError(
                "per_angle_mode='constant' requires fixed_contrast and "
                "fixed_offset arrays of length n_phi."
            )
        fc = _per_angle(fixed_contrast, "fixed_contrast")
        fo = _per_angle(fixed_offset, "fixed_offset")
        return lambda: xpcs_model_heterodyne_constant(
            **common, fixed_contrast=fc, fixed_offset=fo
        )

    if per_angle_mode in ("auto", "averaged"):
        return lambda: xpcs_model_heterodyne_averaged(**common)

    if per_angle_mode == "constant_averaged":
        if fixed_contrast is None or fixed_offset is None:
            raise ValueError(
                "per_angle_mode='constant_averaged' requires scalar "
                "fixed_contrast and fixed_offset."
            )
        # Computed once here rather than on every model trace.
        fc_mean = _averaged(fixed_contrast, "fixed_contrast")
        fo_mean = _averaged(fixed_offset, "fixed_offset")
        return lambda: xpcs_model_heterodyne_constant_averaged(
            **common, fixed_contrast=fc_mean, fixed_offset=fo_mean
        )

    raise ValueError(
        f"Unknown per_angle_mode {per_angle_mode!r}; expected one of "
        "{'individual','scaled','constant','auto','averaged',"
        "'constant_averaged'}"
    )
def _likelihood_boundary_mask(
    c2_data: jnp.ndarray, shard_grid: ShardGrid | None
) -> jnp.ndarray:
    """Return a boolean mask that is False on the t=0 row and column.

    The Bayesian likelihood follows the same t=0 boundary contract as the
    NLSQ mask in ``heterodyne.core.jax_backend``: t=0 is loaded and plotted,
    but excluded from fitting.

    - Element-wise sharded path (``shard_grid`` given): the mask has one
      entry per pair, ``idx1 > 0 and idx2 > 0``.
    - Meshgrid path (``shard_grid is None``): the mask is square, with side
      ``c2_data.shape[-1]``, and True wherever both time indices are > 0.
    """
    if shard_grid is None:
        past_origin = jnp.arange(c2_data.shape[-1]) > 0
        return past_origin[:, None] & past_origin[None, :]
    return (shard_grid.idx1 > 0) & (shard_grid.idx2 > 0)
def get_heterodyne_model(
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_angle: float,
    c2_data: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    contrast: float = 1.0,
    offset: float = 1.0,
    shard_grid: ShardGrid | None = None,
    priors_override: dict | None = None,
    num_shards: int = 1,
):
    """Build the single-angle NumPyro model for heterodyne correlation fitting.

    Each varying parameter is sampled from its prior. The model then
    predicts C2 and samples
    ``sigma ~ HalfNormal(noise_scale * 1.5 * sqrt(num_shards))``, following
    the homodyne parity convention so noise uncertainty is part of the
    posterior. Finally ``c2_data`` is scored with the t=0 row/column masked
    out.

    Args:
        t: Time array.
        q: Wavevector.
        dt: Time step.
        phi_angle: Detector phi angle.
        c2_data: Observed correlation data. Shape is ``(N, N)`` for the
            meshgrid path, or ``(n_pairs,)`` for the element-wise path.
        noise_scale: Data-driven prior centre for ``sigma``, e.g. an
            estimate from :func:`estimate_sigma`.
        space: Parameter space providing priors, varying names and initial
            values.
        contrast: Speckle contrast (beta), default 1.0.
        offset: Baseline offset, default 1.0.
        shard_grid: Optional pre-computed ShardGrid. When given, the
            element-wise path is used (no N×N allocation), and ``c2_data``
            must be flattened to match the grid's paired indices.
        priors_override: Optional mapping from parameter name to a NumPyro
            distribution. Where a name matches, it replaces
            ``space.priors[name]``. ``fit_cmc_sharded`` uses this to inject
            tempered priors.
        num_shards: Number of CMC shards. Widens the ``sigma`` prior by
            ``sqrt(num_shards)``. Default ``1`` (no tempering).

    Returns:
        Zero-argument NumPyro model function.
    """
    sampled = space.varying_names
    initial = space.get_initial_array()
    sigma_scale = float(noise_scale) * 1.5 * math.sqrt(num_shards)
    overrides = priors_override if priors_override is not None else {}

    def model():
        """NumPyro model for heterodyne correlation."""
        # Scatter the draws into the initial vector with .at[].set(), which
        # avoids tracing issues from mixing tracers and concrete values.
        theta = jnp.asarray(initial)
        for idx, pname in enumerate(ALL_PARAM_NAMES):
            if pname not in sampled:
                continue
            if pname in overrides:
                prior_dist = overrides[pname]
            else:
                prior_dist = space.priors[pname].to_numpyro(pname)
            theta = theta.at[idx].set(numpyro.sample(pname, prior_dist))

        if shard_grid is None:
            c2_model = compute_c2_heterodyne(
                theta, t, q, dt, phi_angle, contrast, offset
            )
        else:
            c2_model = compute_c2_elementwise(
                theta, shard_grid, q, dt, phi_angle, contrast, offset
            )

        # Count non-finite predictions so callers can flag pathological shards.
        numpyro.deterministic("n_numerical_issues", jnp.sum(~jnp.isfinite(c2_model)))

        sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))

        # t=0 is excluded from the likelihood only; it is still loaded and
        # plotted.
        likelihood_mask = _likelihood_boundary_mask(c2_data, shard_grid)
        with numpyro.handlers.mask(mask=likelihood_mask):
            numpyro.sample("obs", dist.Normal(c2_model, sigma), obs=c2_data)

    return model
def get_heterodyne_model_reparam(
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_angle: float,
    c2_data: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    nlsq_params: jnp.ndarray | None = None,
    reparam_config: ReparamConfig | None = None,
    scalings: dict[str, ParameterScaling] | None = None,
    contrast: float = 1.0,
    offset: float = 1.0,
    shard_grid: ShardGrid | None = None,
    num_shards: int = 1,
):
    """Build a NumPyro model with reparameterization for better sampling.

    When both ``reparam_config`` and ``scalings`` are given, the model is
    built by ``_build_reparam_model``. That path uses reference-time
    reparameterization of power-law pairs and smooth (tanh-style) bounds
    instead of ``jnp.clip``. Without them, the legacy clip-based model is
    used for backward compatibility.

    Legacy model: each varying parameter is sampled as
    ``<name>_raw ~ Normal(center, (hi - lo) / 6)``. It is then clipped to
    its bounds and recorded as the deterministic site ``<name>``.
    ``center`` comes from ``nlsq_params`` when given, otherwise from
    ``space.values``.

    In both paths ``sigma ~ HalfNormal(noise_scale * 1.5 * sqrt(num_shards))``
    (homodyne CMC parity), and the t=0 row/column is masked out of the
    likelihood.

    Args:
        t: Time array.
        q: Wavevector.
        dt: Time step.
        phi_angle: Detector phi angle.
        c2_data: Observed correlation data.
        noise_scale: Data-driven prior centre for sampled ``sigma``.
        space: Parameter space.
        nlsq_params: Optional NLSQ fitted values, indexed by
            ``ALL_PARAM_NAMES`` order, used as prior centres (legacy path).
        reparam_config: Reparameterization config (enables the new path).
        scalings: Pre-computed ParameterScaling per reparam-space parameter.
        contrast: Speckle contrast passed to the forward model, default 1.0.
        offset: Baseline offset passed to the forward model, default 1.0.
        shard_grid: Optional ShardGrid selecting the element-wise path.
        num_shards: CMC shard count for sigma-prior tempering. Default ``1``.

    Returns:
        Zero-argument NumPyro model function.
    """
    sampled = space.varying_names
    initial = space.get_initial_array()
    sigma_scale = float(noise_scale) * 1.5 * math.sqrt(num_shards)

    # --- Reparameterized path ---
    use_reparam = reparam_config is not None and scalings is not None
    if use_reparam:
        return _build_reparam_model(
            t=t,
            q=q,
            dt=dt,
            phi_angle=phi_angle,
            c2_data=c2_data,
            sigma_scale=sigma_scale,
            space=space,
            fixed_values=jnp.asarray(initial),
            varying_names=sampled,
            reparam_config=reparam_config,
            scalings=scalings,
            contrast=contrast,
            offset=offset,
            shard_grid=shard_grid,
        )

    # --- Legacy clip-based path (backward compatibility) ---
    if nlsq_params is None:
        centers = {pname: space.values[pname] for pname in sampled}
    else:
        centers = {
            pname: float(nlsq_params[ALL_PARAM_NAMES.index(pname)])
            for pname in sampled
        }

    def model():
        """NumPyro model with centered parameterization (legacy)."""
        # .at[].set() avoids tracing issues with mixed tracer/concrete values.
        theta = jnp.asarray(initial)
        for idx, pname in enumerate(ALL_PARAM_NAMES):
            if pname not in sampled:
                continue
            bnd = space.bounds[pname]
            lo, hi = bnd[0], bnd[1]
            width = (hi - lo) / 6.0
            raw = numpyro.sample(f"{pname}_raw", dist.Normal(centers[pname], width))
            # NOTE: jnp.clip has a discontinuous gradient at the bounds; the
            # reparameterized path uses smooth_bound() instead.
            value = jnp.clip(raw, lo, hi)
            numpyro.deterministic(pname, value)
            theta = theta.at[idx].set(value)

        if shard_grid is None:
            c2_model = compute_c2_heterodyne(
                theta, t, q, dt, phi_angle, contrast, offset
            )
        else:
            c2_model = compute_c2_elementwise(
                theta, shard_grid, q, dt, phi_angle, contrast, offset
            )

        numpyro.deterministic("n_numerical_issues", jnp.sum(~jnp.isfinite(c2_model)))

        sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))
        likelihood_mask = _likelihood_boundary_mask(c2_data, shard_grid)
        with numpyro.handlers.mask(mask=likelihood_mask):
            numpyro.sample("obs", dist.Normal(c2_model, sigma), obs=c2_data)

    return model
def _build_reparam_model(
    *,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_angle: float,
    c2_data: jnp.ndarray,
    sigma_scale: float,
    space: ParameterSpace,
    fixed_values: jnp.ndarray,
    varying_names: list[str],
    reparam_config: ReparamConfig,
    scalings: dict[str, ParameterScaling],
    contrast: float = 1.0,
    offset: float = 1.0,
    shard_grid: ShardGrid | None = None,
):
    """Build NumPyro model using reference-time reparameterization + smooth bounds.

    For every ``(prefactor, exponent)`` pair in ``reparam_config.enabled_pairs``
    where BOTH names are varying, the prefactor is sampled under its log-space
    name ``reparam_config.get_reparam_name(prefactor)`` and back-transformed to
    physics space with :func:`reparam_to_physics_jax` using the pair's exponent
    and ``reparam_config.t_ref``. All other varying names are sampled directly.

    Each sampling-space name ``s`` that has an entry in ``scalings`` is drawn
    as ``s_z ~ Normal(0, 1)`` and mapped through ``scalings[s].to_original``.
    Names without a scaling are not sampled and keep their value from
    ``fixed_values``. In particular, if a reparameterized pair's exponent has no
    scaling, its fixed value is used for the prefactor back-transform.

    Args:
        t: Time array.
        q: Wavevector magnitude.
        dt: Lag-time step.
        phi_angle: Detector phi angle for this shard.
        c2_data: Observed correlation data.
        sigma_scale: Scale of the ``HalfNormal`` prior on ``sigma``.
        space: Parameter space (accepted for interface parity; not read here).
        fixed_values: Full parameter vector providing non-sampled values,
            indexed via ``PARAM_INDICES``.
        varying_names: Names of parameters that vary.
        reparam_config: Reparameterization config (pairs, ``t_ref``, naming).
        scalings: Per-sampling-name z-space scalings.
        contrast: Fixed speckle contrast passed to the C2 kernel.
        offset: Fixed baseline offset passed to the C2 kernel.
        shard_grid: Optional ShardGrid; selects the element-wise C2 path.

    Returns:
        NumPyro model callable (no required arguments).
    """
    t_ref = reparam_config.t_ref

    # log-space sampling name -> (physics prefactor, physics exponent)
    log_pairs: dict[str, tuple[str, str]] = {}
    prefactor_to_log: dict[str, str] = {}
    for prefactor, exponent in reparam_config.enabled_pairs:
        if prefactor in varying_names and exponent in varying_names:
            log_name = reparam_config.get_reparam_name(prefactor)
            prefactor_to_log[prefactor] = log_name
            log_pairs[log_name] = (prefactor, exponent)

    # Sampling-space names, in varying_names order: reparameterized
    # prefactors are replaced by their log-space names.
    sampling_names = [prefactor_to_log.get(name, name) for name in varying_names]

    def model():
        """NumPyro model with reference-time reparam + smooth bounds."""
        # Sample in z-space, then transform to the bounded original scale.
        sampled_values: dict[str, jnp.ndarray] = {}
        for sname in sampling_names:
            if sname not in scalings:
                # No scaling: parameter stays at its fixed value.
                continue
            z = numpyro.sample(f"{sname}_z", dist.Normal(0.0, 1.0))
            sampled_values[sname] = scalings[sname].to_original(z)

        # Back-transform reparameterized pairs to physics space.
        physics_values: dict[str, jnp.ndarray] = {}
        for sname, value in sampled_values.items():
            pair = log_pairs.get(sname)
            if pair is not None:
                prefactor, exponent = pair
                alpha = sampled_values.get(exponent)
                if alpha is None:
                    # Exponent had no scaling and was not sampled: use its
                    # fixed value instead of failing with a KeyError.
                    alpha = fixed_values[PARAM_INDICES[exponent]]
                a0 = reparam_to_physics_jax(value, alpha, t_ref)
                physics_values[prefactor] = a0
                # Physics-space prefactor plus the log value for diagnostics.
                numpyro.deterministic(prefactor, a0)
                numpyro.deterministic(sname, value)
            elif sname not in physics_values:
                # Direct parameter (exponent or non-reparameterized).
                physics_values[sname] = value
                numpyro.deterministic(sname, value)

        # Assemble full parameter array using scatter (handles MCMC batch dims).
        # squeeze() removes any singleton batch dimensions from chain
        # vectorization so values match the scalar elements of fixed_values.
        params = jnp.asarray(fixed_values)
        for name, value in physics_values.items():
            params = params.at[PARAM_INDICES[name]].set(jnp.squeeze(value))

        if shard_grid is not None:
            c2_model = compute_c2_elementwise(
                params, shard_grid, q, dt, phi_angle, contrast, offset,
            )
        else:
            c2_model = compute_c2_heterodyne(
                params, t, q, dt, phi_angle, contrast, offset,
            )

        n_nan = jnp.sum(~jnp.isfinite(c2_model))
        numpyro.deterministic("n_numerical_issues", n_nan)

        sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))
        with numpyro.handlers.mask(mask=_likelihood_boundary_mask(c2_data, shard_grid)):
            numpyro.sample("obs", dist.Normal(c2_model, sigma), obs=c2_data)

    return model


# ---------------------------------------------------------------------------
# Per-angle mode models
# ---------------------------------------------------------------------------
def get_heterodyne_model_constant(
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_angle: float,
    c2_data: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    fixed_contrast: jnp.ndarray,
    fixed_offset: jnp.ndarray,
    shard_grid: ShardGrid | None = None,
    num_shards: int = 1,
):
    """Create NumPyro model with FIXED (pre-computed) per-angle scaling.

    Contrast and offset are not sampled — they are provided as fixed arrays
    from a preceding NLSQ or preprocessing step. Shared physical parameters
    are drawn with :func:`_heterodyne_sample_shared_physics`, which skips
    ``contrast`` / ``offset`` even when ``space`` marks them as varying. This
    avoids adding prior-only sample sites that the likelihood never uses.

    Suitable for ``per_angle_mode="constant"``.

    Sigma is sampled internally via
    ``HalfNormal(noise_scale * 1.5 * sqrt(num_shards))`` for homodyne CMC parity.

    Args:
        t: Time array, shape ``(n_t,)``.
        q: Wavevector magnitude (Å⁻¹).
        dt: Lag-time step (s).
        phi_angle: Detector phi angle for this shard (degrees).
        c2_data: Observed correlation data.
        noise_scale: Data-driven prior center for sampled ``sigma``.
        space: Parameter space carrying priors and fixed values.
        fixed_contrast: Speckle contrast, scalar or shape ``(n_phi,)``. An
            array is reduced to its mean, and that single value is used for
            this model's ``phi_angle``.
        fixed_offset: Baseline offset, scalar or shape ``(n_phi,)``; reduced
            to its mean like ``fixed_contrast``.
        shard_grid: Optional pre-computed ShardGrid. When given, the
            element-wise :func:`compute_c2_elementwise` path is used instead
            of :func:`compute_c2_heterodyne`.
        num_shards: CMC shard count for sigma prior tempering. Default ``1``.

    Returns:
        NumPyro model callable (no required arguments).
    """
    sigma_scale = float(noise_scale) * 1.5 * math.sqrt(num_shards)

    # Materialise and reduce the fixed scaling once, outside the model
    # closure, so it enters the trace as a constant rather than per-trace work.
    # NOTE(review): per-angle arrays are averaged rather than indexed by
    # phi_angle — confirm this is the intended "constant" semantics.
    contrast_val = jnp.mean(jnp.asarray(fixed_contrast))
    offset_val = jnp.mean(jnp.asarray(fixed_offset))

    def model():
        """NumPyro model with fixed per-angle contrast and offset."""
        # contrast/offset are fixed here, so they must not be sampled.
        params = _heterodyne_sample_shared_physics(space)

        if shard_grid is not None:
            c2_model = compute_c2_elementwise(
                params, shard_grid, q, dt, phi_angle, contrast_val, offset_val,
            )
        else:
            c2_model = compute_c2_heterodyne(
                params, t, q, dt, phi_angle, contrast_val, offset_val,
            )

        n_nan = jnp.sum(~jnp.isfinite(c2_model))
        numpyro.deterministic("n_numerical_issues", n_nan)

        sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))
        with numpyro.handlers.mask(mask=_likelihood_boundary_mask(c2_data, shard_grid)):
            numpyro.sample("obs", dist.Normal(c2_model, sigma), obs=c2_data)

    return model
def get_heterodyne_model_constant_averaged(
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_angle: float,
    c2_data: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    mean_contrast: float,
    mean_offset: float,
    shard_grid: ShardGrid | None = None,
    num_shards: int = 1,
):
    """Create NumPyro model with a single averaged scaling broadcast to all angles.

    Both ``mean_contrast`` and ``mean_offset`` are scalars computed from the
    average over all phi angles. They are treated as fixed (not sampled).
    Shared physical parameters are drawn with
    :func:`_heterodyne_sample_shared_physics`, which skips ``contrast`` /
    ``offset`` even when ``space`` marks them as varying. As a result, no
    unused prior-only sample sites are created for them.

    Suitable for ``per_angle_mode="constant_averaged"``.

    Sigma is sampled internally via
    ``HalfNormal(noise_scale * 1.5 * sqrt(num_shards))`` for homodyne CMC parity.

    Args:
        t: Time array, shape ``(n_t,)``.
        q: Wavevector magnitude (Å⁻¹).
        dt: Lag-time step (s).
        phi_angle: Detector phi angle for this shard (degrees).
        c2_data: Observed correlation data.
        noise_scale: Data-driven prior center for sampled ``sigma``.
        space: Parameter space carrying priors and fixed values.
        mean_contrast: Scalar speckle contrast averaged over all phi angles.
        mean_offset: Scalar baseline offset averaged over all phi angles.
        shard_grid: Optional pre-computed ShardGrid. When given, the
            element-wise :func:`compute_c2_elementwise` path is used instead
            of :func:`compute_c2_heterodyne`.
        num_shards: CMC shard count for sigma prior tempering. Default ``1``.

    Returns:
        NumPyro model callable (no required arguments).
    """
    sigma_scale = float(noise_scale) * 1.5 * math.sqrt(num_shards)

    # Ensure Python floats to avoid accidental JAX tracing at closure time.
    _contrast = float(mean_contrast)
    _offset = float(mean_offset)

    def model():
        """NumPyro model with angle-averaged fixed contrast and offset."""
        # contrast/offset are fixed here, so they must not be sampled.
        params = _heterodyne_sample_shared_physics(space)

        if shard_grid is not None:
            c2_model = compute_c2_elementwise(
                params, shard_grid, q, dt, phi_angle, _contrast, _offset,
            )
        else:
            c2_model = compute_c2_heterodyne(
                params, t, q, dt, phi_angle, _contrast, _offset,
            )

        n_nan = jnp.sum(~jnp.isfinite(c2_model))
        numpyro.deterministic("n_numerical_issues", n_nan)

        sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))
        with numpyro.handlers.mask(mask=_likelihood_boundary_mask(c2_data, shard_grid)):
            numpyro.sample("obs", dist.Normal(c2_model, sigma), obs=c2_data)

    return model
def get_heterodyne_model_individual(
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_angles: jnp.ndarray,
    c2_data: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    contrast_prior_loc: jnp.ndarray | float = 0.5,
    contrast_prior_scale: float = 0.25,
    offset_prior_loc: jnp.ndarray | float = 1.0,
    offset_prior_scale: float = 0.25,
    shard_grids: list[ShardGrid] | None = None,
    num_shards: int = 1,
):
    """Create NumPyro model with per-angle sampled contrast and offset.

    The most general per-angle model: independently samples ``contrast_i``
    and ``offset_i`` for each phi angle. Inside ``numpyro.plate("angles",
    n_phi)``, each is drawn as ``z ~ Normal(0, 1)``, then shifted and scaled
    by its prior loc/scale. It is then mapped through :func:`smooth_bound`
    onto ``[0, 1]`` (contrast) or ``[0.5, 1.5]`` (offset). The bounded values
    are exposed as the deterministic sites ``"contrast"`` and ``"offset"``.

    Physical parameters are shared across all angles. They are drawn with
    :func:`_heterodyne_sample_shared_physics`, which never creates
    ``contrast`` / ``offset`` sample sites. NumPyro requires unique site names,
    so sampling a varying ``contrast`` from ``space`` would collide with the
    per-angle ``"contrast"`` deterministic site registered here.

    A single ``sigma ~ HalfNormal(noise_scale * 1.5 * sqrt(num_shards))`` is
    shared across all angles (homodyne CMC parity).

    Suitable for ``per_angle_mode="individual"``.

    Args:
        t: Time array, shape ``(n_t,)``.
        q: Wavevector magnitude (Å⁻¹).
        dt: Lag-time step (s).
        phi_angles: Detector phi angles, shape ``(n_phi,)``.
        c2_data: Observed correlation data, indexed per angle as
            ``c2_data[ai]``.
        noise_scale: Data-driven prior center for the sampled ``sigma``.
        space: Parameter space carrying priors and fixed values.
        contrast_prior_loc: Prior centre(s) for contrast. Scalar or
            ``(n_phi,)`` array. Default ``0.5``.
        contrast_prior_scale: Prior width for contrast. Default ``0.25``.
        offset_prior_loc: Prior centre(s) for offset. Scalar or ``(n_phi,)``
            array. Default ``1.0``.
        offset_prior_scale: Prior width for offset. Default ``0.25``.
        shard_grids: Optional list of pre-computed ShardGrids, one per phi
            angle. When provided, it selects the memory-efficient element-wise
            path, which avoids an N×N allocation per angle. Each
            ``c2_data[ai]`` must then match its shard grid's paired indices.
            Without this, the model builds n_phi N×N matrices per NUTS step,
            which can cause OOM for large datasets.
        num_shards: CMC shard count for sigma prior tempering. Default ``1``.

    Returns:
        NumPyro model callable (no required arguments).

    Raises:
        ValueError: If ``shard_grids`` is given and its length differs from
            the number of phi angles.
    """
    phi_arr = jnp.asarray(phi_angles)
    n_phi = phi_arr.shape[0]

    if shard_grids is not None and len(shard_grids) != n_phi:
        raise ValueError(
            f"shard_grids length {len(shard_grids)} must match n_phi {n_phi}"
        )

    # Concrete Python floats, converted once instead of on every trace.
    phi_values = [float(p) for p in np.asarray(phi_arr)]

    contrast_loc = jnp.broadcast_to(jnp.asarray(contrast_prior_loc), (n_phi,))
    offset_loc = jnp.broadcast_to(jnp.asarray(offset_prior_loc), (n_phi,))

    sigma_scale = float(noise_scale) * 1.5 * math.sqrt(num_shards)

    def model():
        """NumPyro model with per-angle sampled contrast and offset."""
        # --- Shared physical parameters (contrast/offset excluded) ---
        params = _heterodyne_sample_shared_physics(space)

        # --- Per-angle scaling sampled in z-space + smooth_bound ---
        # Homodyne parity: sample in unconstrained z-space, then transform
        # via smooth_bound for NUTS-safe gradients.
        with numpyro.plate("angles", n_phi):
            contrast_z = numpyro.sample("contrast_z", dist.Normal(0.0, 1.0))
            offset_z = numpyro.sample("offset_z", dist.Normal(0.0, 1.0))

            contrast_raw = contrast_loc + contrast_prior_scale * contrast_z
            contrast_i = smooth_bound(contrast_raw, 0.0, 1.0)
            numpyro.deterministic("contrast", contrast_i)

            offset_raw = offset_loc + offset_prior_scale * offset_z
            offset_i = smooth_bound(offset_raw, 0.5, 1.5)
            numpyro.deterministic("offset", offset_i)

        # --- Sigma sampled once and shared across angles (homodyne parity) ---
        sigma = numpyro.sample("sigma", dist.HalfNormal(scale=sigma_scale))

        # --- Likelihood over all angles ---
        # A Python loop over the static angle count keeps tracing simple.
        n_total_nan: jnp.ndarray | int = 0
        for ai, phi_value in enumerate(phi_values):
            if shard_grids is not None:
                c2_model_i = compute_c2_elementwise(
                    params, shard_grids[ai], q, dt, phi_value,
                    contrast_i[ai], offset_i[ai],
                )
            else:
                c2_model_i = compute_c2_heterodyne(
                    params, t, q, dt, phi_value,
                    contrast_i[ai], offset_i[ai],
                )
            n_total_nan = n_total_nan + jnp.sum(~jnp.isfinite(c2_model_i))

            sg_i = shard_grids[ai] if shard_grids is not None else None
            with numpyro.handlers.mask(
                mask=_likelihood_boundary_mask(c2_data[ai], sg_i)
            ):
                numpyro.sample(
                    f"obs_{ai}",
                    dist.Normal(c2_model_i, sigma),
                    obs=c2_data[ai],
                )

        numpyro.deterministic("n_numerical_issues", n_total_nan)

    return model
def get_model_for_mode(
    per_angle_mode: str,
    t: jnp.ndarray,
    q: float,
    dt: float,
    phi_angle: float,
    c2_data: jnp.ndarray,
    noise_scale: float,
    space: ParameterSpace,
    nlsq_result: NLSQResult | None = None,
    reparam_config: ReparamConfig | None = None,
    num_shards: int = 1,
    **kwargs: object,
) -> Callable[[], None]:
    """Select and build the appropriate NumPyro model based on per-angle mode.

    Factory that maps ``per_angle_mode`` strings to concrete model
    constructors. Mode-specific keyword arguments are popped from ``kwargs``.
    In ``"individual"`` mode, whatever remains is forwarded to the
    constructor. In the other modes, leftovers are not used and are reported
    at debug level.

    Mapping
    -------
    ``"auto"``
        Delegates to :func:`get_heterodyne_model_reparam` when
        ``reparam_config`` is supplied, otherwise to
        :func:`get_heterodyne_model`. Reads optional ``scalings`` (only used
        on the reparam path), ``contrast`` (default ``1.0``), ``offset``
        (default ``1.0``) and ``shard_grid`` (default ``None``) from
        ``kwargs``.
    ``"constant"``
        Delegates to :func:`get_heterodyne_model_constant`. Requires
        ``fixed_contrast`` and ``fixed_offset`` in ``kwargs`` (a missing key
        raises ``KeyError``); ``shard_grid`` is optional.
    ``"constant_averaged"``
        Delegates to :func:`get_heterodyne_model_constant_averaged`. Reads
        ``mean_contrast`` and ``mean_offset`` from ``kwargs``, each defaulting
        to ``1.0`` when absent; ``shard_grid`` is optional.
    ``"individual"``
        Delegates to :func:`get_heterodyne_model_individual`. Requires
        ``phi_angles`` in ``kwargs``; ``shard_grids`` is optional, and any
        remaining kwargs (e.g. prior locs/scales) are forwarded verbatim.

    Args:
        per_angle_mode: One of ``"auto"``, ``"constant"``,
            ``"constant_averaged"``, ``"individual"``.
        t: Time array.
        q: Wavevector magnitude (Å⁻¹).
        dt: Lag-time step (s).
        phi_angle: Scalar phi angle (used by non-individual modes).
        c2_data: Observed correlation data.
        noise_scale: Data-driven prior centre for the sampled ``sigma`` site.
        space: Parameter space.
        nlsq_result: Accepted for API compatibility. It is currently NOT
            forwarded to any model constructor.
        reparam_config: Optional reparameterization config. When provided
            alongside ``"auto"`` mode, activates the reparam model path.
        num_shards: CMC shard count for sigma prior tempering. Default ``1``.
        **kwargs: Mode-specific keyword arguments (see Mapping).

    Returns:
        NumPyro model callable (no required arguments).

    Raises:
        ValueError: If ``per_angle_mode`` is not a recognised string.
        KeyError: If a required mode-specific kwarg is missing.
    """
    _VALID_MODES = frozenset({"auto", "constant", "constant_averaged", "individual"})
    if per_angle_mode not in _VALID_MODES:
        raise ValueError(
            f"Unknown per_angle_mode '{per_angle_mode}'. "
            f"Valid options: {sorted(_VALID_MODES)}"
        )

    def _log_unused_kwargs() -> None:
        """Report kwargs left over after the selected mode popped its keys."""
        if kwargs:
            logger.debug(
                "get_model_for_mode(mode=%s): ignoring unused kwargs %s",
                per_angle_mode,
                sorted(kwargs),
            )

    if per_angle_mode == "auto":
        scalings: dict[str, ParameterScaling] | None = kwargs.pop(  # type: ignore[assignment]
            "scalings", None
        )
        contrast: float = float(kwargs.pop("contrast", 1.0))  # type: ignore[arg-type]
        offset: float = float(kwargs.pop("offset", 1.0))  # type: ignore[arg-type]
        sg: ShardGrid | None = kwargs.pop("shard_grid", None)  # type: ignore[assignment]
        _log_unused_kwargs()

        if reparam_config is not None:
            return get_heterodyne_model_reparam(
                t=t,
                q=q,
                dt=dt,
                phi_angle=phi_angle,
                c2_data=c2_data,
                noise_scale=noise_scale,
                space=space,
                reparam_config=reparam_config,
                scalings=scalings,
                contrast=contrast,
                offset=offset,
                shard_grid=sg,
                num_shards=num_shards,
            )
        return get_heterodyne_model(
            t=t,
            q=q,
            dt=dt,
            phi_angle=phi_angle,
            c2_data=c2_data,
            noise_scale=noise_scale,
            space=space,
            contrast=contrast,
            offset=offset,
            shard_grid=sg,
            num_shards=num_shards,
        )

    if per_angle_mode == "constant":
        fixed_contrast = kwargs.pop("fixed_contrast")
        fixed_offset = kwargs.pop("fixed_offset")
        sg_const: ShardGrid | None = kwargs.pop("shard_grid", None)  # type: ignore[assignment]
        _log_unused_kwargs()
        return get_heterodyne_model_constant(
            t=t,
            q=q,
            dt=dt,
            phi_angle=phi_angle,
            c2_data=c2_data,
            noise_scale=noise_scale,
            space=space,
            fixed_contrast=fixed_contrast,  # type: ignore[arg-type]
            fixed_offset=fixed_offset,  # type: ignore[arg-type]
            shard_grid=sg_const,
            num_shards=num_shards,
        )

    if per_angle_mode == "constant_averaged":
        mean_contrast = float(kwargs.pop("mean_contrast", 1.0))  # type: ignore[arg-type]
        mean_offset = float(kwargs.pop("mean_offset", 1.0))  # type: ignore[arg-type]
        sg_avg: ShardGrid | None = kwargs.pop("shard_grid", None)  # type: ignore[assignment]
        _log_unused_kwargs()
        return get_heterodyne_model_constant_averaged(
            t=t,
            q=q,
            dt=dt,
            phi_angle=phi_angle,
            c2_data=c2_data,
            noise_scale=noise_scale,
            space=space,
            mean_contrast=mean_contrast,
            mean_offset=mean_offset,
            shard_grid=sg_avg,
            num_shards=num_shards,
        )

    # per_angle_mode == "individual": remaining kwargs are forwarded verbatim.
    phi_angles = kwargs.pop("phi_angles")
    sg_individual: list[ShardGrid] | None = kwargs.pop("shard_grids", None)  # type: ignore[assignment]
    return get_heterodyne_model_individual(
        t=t,
        q=q,
        dt=dt,
        phi_angles=phi_angles,  # type: ignore[arg-type]
        c2_data=c2_data,
        noise_scale=noise_scale,
        space=space,
        shard_grids=sg_individual,
        num_shards=num_shards,
        **kwargs,  # type: ignore[arg-type]
    )
# --------------------------------------------------------------------------- # Sigma estimation # ---------------------------------------------------------------------------
def _c2_diagonal_samples(c2_data: jnp.ndarray) -> jnp.ndarray:
    """Return the diagonal (last two axes) of ``c2_data`` as a flat 1-D array.

    Supports a single ``(n_t, n_t)`` matrix and stacks ``(..., n_t, n_t)``.
    Inputs with fewer than two dimensions have no diagonal, so their values
    are used directly (flattened).
    """
    arr = jnp.asarray(c2_data)
    if arr.ndim < 2:
        return jnp.ravel(arr)
    return jnp.ravel(jnp.diagonal(arr, axis1=-2, axis2=-1))


def _floor_sigma_gradient_safe(
    value: jnp.ndarray, c2_data: jnp.ndarray, fraction: float
) -> jnp.ndarray:
    """Floor ``value`` at ``fraction * max(std(c2_data), 1e-6)``.

    Uses ``jnp.where`` rather than ``jnp.maximum`` (project "Rule 7",
    gradient safety).
    """
    std = jnp.std(c2_data)
    data_scale = jnp.where(std > 1e-6, std, 1e-6)
    floor = fraction * data_scale
    return jnp.where(value > floor, value, floor)


def estimate_sigma(
    c2_data: jnp.ndarray,
    method: str = "diagonal",
    nlsq_result: NLSQResult | None = None,
    n_bootstrap: int = 200,
    bootstrap_seed: int = 0,
) -> jnp.ndarray:
    """Estimate measurement uncertainty from data.

    Supported methods:

    - ``"diagonal"`` -- Standard deviation of the ``c2_data`` diagonal (last
      two axes; 1-D input is used as-is), floored at 1 % of the data's
      overall scale. Fast and requires no additional information.
    - ``"constant"`` -- Overall standard deviation of ``c2_data`` as a
      scalar, floored at 1 % of the data scale so it is never exactly zero.
    - ``"local"`` -- Spatially smoothed local standard deviation via
      ``scipy.ndimage.uniform_filter`` (window 5). Requires SciPy.
    - ``"residual"`` -- RMS of NLSQ residuals, floored at 1 % of the data
      scale. Requires ``nlsq_result`` with a non-``None`` ``residuals``
      field; falls back to ``"diagonal"`` otherwise.
    - ``"bootstrap"`` -- Draws ``n_bootstrap`` resamples of the diagonal.
      The spread of the replicate means is the standard error, about
      ``sigma / sqrt(n)``. It is therefore rescaled by ``sqrt(n)`` to give a
      per-point noise level, floored at 0.1 % of the data scale.

    Args:
        c2_data: Correlation data, typically ``(n_t, n_t)`` or a stack
            ``(..., n_t, n_t)``.
        method: Estimation method — one of ``"diagonal"``, ``"constant"``,
            ``"local"``, ``"residual"``, ``"bootstrap"``.
        nlsq_result: NLSQ result object. Used only for ``method="residual"``.
        n_bootstrap: Number of bootstrap replicates for
            ``method="bootstrap"``. Default ``200``.
        bootstrap_seed: JAX PRNG seed for ``method="bootstrap"``. Default
            ``0``.

    Returns:
        Estimated sigma — same shape as ``c2_data`` for ``"local"``, a scalar
        for all other methods.

    Raises:
        ValueError: If ``method`` is not a recognised string.
    """
    import jax

    if method == "diagonal":
        # Spread of the diagonal as a proxy for noise. The floor avoids a
        # near-zero sigma for normalized data with a very uniform diagonal.
        diag = _c2_diagonal_samples(c2_data)
        sigma = jnp.std(diag - jnp.mean(diag))
        return _floor_sigma_gradient_safe(sigma, c2_data, 0.01)

    elif method == "constant":
        # Overall standard deviation; floored so constant data cannot yield
        # a zero (invalid) noise scale.
        return _floor_sigma_gradient_safe(jnp.std(c2_data), c2_data, 0.01)

    elif method == "local":
        # Local variance estimation: E[x^2] - E[x]^2 over a 5-wide window.
        from scipy.ndimage import uniform_filter

        c2_np = np.asarray(c2_data)
        mean_local = uniform_filter(c2_np, size=5, mode="reflect")
        var_local = uniform_filter(c2_np**2, size=5, mode="reflect") - mean_local**2
        sigma_np = np.sqrt(np.maximum(var_local, 1e-12))
        return jnp.asarray(sigma_np)

    elif method == "residual":
        if nlsq_result is not None and nlsq_result.residuals is not None:
            residuals = jnp.asarray(nlsq_result.residuals)
            rms = jnp.sqrt(jnp.mean(residuals**2))
            return _floor_sigma_gradient_safe(rms, c2_data, 0.01)
        # Fall back gracefully so callers don't need to guard against None.
        return estimate_sigma(c2_data, method="diagonal")

    elif method == "bootstrap":
        diag = _c2_diagonal_samples(c2_data)
        n = diag.shape[0]
        _, subkey = jax.random.split(jax.random.PRNGKey(bootstrap_seed))
        # Resample indices with replacement: shape (n_bootstrap, n).
        indices = jax.random.randint(subkey, shape=(n_bootstrap, n), minval=0, maxval=n)
        # Replicate means: shape (n_bootstrap,).
        replicate_means = jnp.mean(diag[indices], axis=1)
        # std of bootstrap means is the standard error (~ sigma / sqrt(n));
        # rescale to the per-point noise level.
        sigma_boot = jnp.std(replicate_means) * math.sqrt(n)
        return _floor_sigma_gradient_safe(sigma_boot, c2_data, 0.001)

    else:
        raise ValueError(
            f"Unknown method '{method}'. Valid options: "
            "'diagonal', 'constant', 'local', 'residual', 'bootstrap'."
        )
# --------------------------------------------------------------------------- # Model output validation and parameter counting # ---------------------------------------------------------------------------
def validate_model_output(
    c2_theory: jnp.ndarray,
    params: jnp.ndarray,
) -> bool:
    """Check that a theoretical C2 array is finite and within physical range.

    NaN values are reported first, then inf values. The last check is the
    heterodyne range window ``[-1.0, 10.0]``. Negative values are allowed
    because the velocity phase term can push heterodyne C2 below zero,
    unlike homodyne where C2 >= 0. The first failing check logs a warning
    that includes ``params`` and stops evaluation.

    Args:
        c2_theory: Theoretical C2 array from model evaluation.
        params: Parameter array that produced ``c2_theory``; only used in
            the warning messages.

    Returns:
        ``True`` when every check passes, ``False`` otherwise.
    """
    non_finite_checks = (
        (jnp.isnan, "validate_model_output: NaN detected in C2 theory (params=%s)"),
        (jnp.isinf, "validate_model_output: inf detected in C2 theory (params=%s)"),
    )
    for is_bad, message in non_finite_checks:
        if bool(jnp.any(is_bad(c2_theory))):
            logger.warning(message, params)
            return False

    lowest = float(jnp.min(c2_theory))
    highest = float(jnp.max(c2_theory))
    within_range = lowest >= -1.0 and highest <= 10.0
    if not within_range:
        logger.warning(
            "validate_model_output: C2 range [%.4e, %.4e] exceeds "
            "physical bounds [-1.0, 10.0] (params=%s)",
            lowest,
            highest,
            params,
        )
    return within_range
def get_model_param_count(
    n_phi: int,
    per_angle_mode: str = "individual",
) -> int:
    """Count the parameters the model samples for a given per-angle mode.

    The physics part is the number of names in ``ALL_PARAM_NAMES`` whose
    ``DEFAULT_REGISTRY`` entry has ``vary_default`` set. The per-angle part
    depends on the mode:

    * ``"individual"`` — ``2 * n_phi`` (one ``contrast_z`` and one
      ``offset_z`` per angle).
    * ``"constant"``, ``"constant_averaged"``, ``"auto"`` — nothing extra.

    Args:
        n_phi: Number of scattering angles.
        per_angle_mode: One of ``"constant"``, ``"constant_averaged"``,
            ``"auto"``, ``"individual"``.

    Returns:
        Total number of sampled parameters.

    Raises:
        ValueError: If ``per_angle_mode`` is not recognised.
    """
    valid_modes = frozenset({"auto", "constant", "constant_averaged", "individual"})
    if per_angle_mode not in valid_modes:
        raise ValueError(
            f"Unknown per_angle_mode '{per_angle_mode}'. "
            f"Valid options: {sorted(valid_modes)}"
        )

    varying_by_default = [
        name for name in ALL_PARAM_NAMES if DEFAULT_REGISTRY[name].vary_default
    ]
    n_physics = len(varying_by_default)
    # Only the "individual" mode adds per-angle sample sites.
    n_per_angle = 2 * n_phi if per_angle_mode == "individual" else 0

    total = n_physics + n_per_angle
    logger.debug(
        "get_model_param_count: n_physics=%d, n_per_angle=%d (mode=%s, n_phi=%d) -> %d",
        n_physics,
        n_per_angle,
        per_angle_mode,
        n_phi,
        total,
    )
    return total