"""Numerically safe mathematical primitives for heterodyne physics.
All functions are designed to be compatible with both NumPy and JAX
arrays, avoiding NaN/Inf from edge cases (division by zero,
overflow in exp, negative bases in power).
Shared utilities used by both the NLSQ meshgrid path and the CMC
element-wise path:
- ``trapezoid_cumsum``: O(dt²) cumulative integral
- ``create_time_integral_matrix``: N×N from cumsum (NLSQ only)
- ``smooth_abs``: gradient-safe ``|x|`` for NUTS
- ``compute_transport_rate``: J(t) = D0·t^α + offset
- ``compute_velocity_rate``: v(t) = v0·t^β + v_offset
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np
# ---------------------------------------------------------------------------
# Numerically safe math primitives
# ---------------------------------------------------------------------------
[docs]
def safe_exp(x: jnp.ndarray | np.ndarray, limit: float = 500.0) -> jnp.ndarray:
    """Overflow-protected exponential.

    The argument is clipped to the symmetric interval ``[-limit, limit]``
    before ``exp`` is applied, so very large inputs saturate at
    ``exp(limit)`` instead of producing Inf. With the default ``limit=500``
    the ceiling is ``exp(500) ≈ 1.4e217``, which fits in float64.

    Note: float32 ``exp`` already overflows for arguments above ~88.7, so
    the default limit only guarantees finite output in float64.

    Args:
        x: Input array.
        limit: Symmetric clipping threshold applied to the argument.

    Returns:
        ``exp(clip(x, -limit, limit))`` with the same shape as ``x``.
    """
    lower, upper = -limit, limit
    bounded = jnp.clip(jnp.asarray(x), lower, upper)
    return jnp.exp(bounded)
[docs]
def safe_power(base: jnp.ndarray | np.ndarray, exponent: float) -> jnp.ndarray:
    """Power function that is safe for non-positive bases.

    Entries with ``base <= 0`` (and NaN entries, since ``NaN > 0`` is
    false) yield ``0.0``, which is the physical limit for t^α transport.
    Positive entries yield ``base**exponent``. Positive bases at or below
    ``1e-30`` are evaluated as ``1e-30**exponent``, so ``jnp.power`` never
    sees a zero or negative input in either the value or the gradient path.

    Args:
        base: Base array (typically time values).
        exponent: Power exponent.

    Returns:
        Safe power result, same shape as ``base``.
    """
    tiny = 1e-30
    b = jnp.asarray(base)
    # Stand-in value keeps jnp.power (and its derivatives) finite on the
    # entries that are masked out below; the gradient w.r.t. ``base`` is
    # zero for entries at or below the stand-in.
    clamped = jnp.where(b > tiny, b, tiny)
    powered = jnp.power(clamped, exponent)
    is_positive = b > 0
    return jnp.where(is_positive, powered, 0.0)
[docs]
def safe_divide(
    numerator: jnp.ndarray | np.ndarray,
    denominator: jnp.ndarray | np.ndarray,
    fill: float = 0.0,
    min_denom: float = 1e-30,
) -> jnp.ndarray:
    """Element-wise division guarded against zero or near-zero denominators.

    Wherever ``|denominator| > min_denom`` the ordinary quotient is
    returned; everywhere else the result is ``fill``. The provisional
    division is always performed against a non-zero value, so no
    intermediate Inf/NaN is created on the masked entries.

    Args:
        numerator: Dividend array.
        denominator: Divisor array.
        fill: Value returned where the denominator is too small.
        min_denom: Smallest absolute denominator treated as usable.

    Returns:
        Safe quotient, broadcast shape of the inputs.
    """
    num = jnp.asarray(numerator)
    den = jnp.asarray(denominator)
    usable = jnp.abs(den) > min_denom
    # Signed stand-in for tiny denominators. Choosing the sign with
    # ``den >= 0`` (rather than sign(den)) maps den == 0 to +min_denom
    # instead of 0, which would otherwise give 0/0 in the masked branch.
    stand_in = jnp.where(den >= 0, min_denom, -min_denom)
    quotient = num / jnp.where(usable, den, stand_in)
    return jnp.where(usable, quotient, fill)
[docs]
def safe_log(x: jnp.ndarray | np.ndarray, floor: float = 1e-30) -> jnp.ndarray:
    """Natural logarithm protected against non-positive arguments.

    Entries at or below ``floor`` are replaced by ``floor`` before the log
    is taken, so the result is ``log(max(x, floor))`` and never -Inf/NaN.

    Args:
        x: Input array.
        floor: Smallest value passed to ``log``.

    Returns:
        ``log(max(x, floor))``, same shape as ``x``.
    """
    values = jnp.asarray(x)
    # Replaced entries get a zero gradient w.r.t. ``x``; entries above the
    # floor pass their gradient through unchanged.
    bounded = jnp.where(values > floor, values, floor)
    return jnp.log(bounded)
[docs]
def safe_sqrt(x: jnp.ndarray | np.ndarray) -> jnp.ndarray:
    """Square root with protection against non-positive arguments.

    Returns ``sqrt(x)`` where ``x > 0`` and ``0.0`` elsewhere (including
    NaN entries, since ``NaN > 0`` is false), i.e. ``sqrt(max(x, 0))``.

    Implemented with the "double-where" pattern: non-positive entries are
    replaced by ``1.0`` *before* ``jnp.sqrt`` and masked back to ``0.0``
    afterwards. Feeding ``0.0`` directly into ``jnp.sqrt`` makes its
    derivative ``0.5 / sqrt(0) = inf``; forward-mode differentiation
    (``jax.jvp`` / ``jax.jacfwd``) then multiplies that by the zero
    tangent from the mask and yields NaN. With the stand-in value the
    masked branch has a finite derivative, so both forward- and
    reverse-mode gradients are 0 for ``x <= 0``.

    Args:
        x: Input array.

    Returns:
        ``sqrt(max(x, 0))``, same shape as ``x``.
    """
    x = jnp.asarray(x)
    positive = x > 0.0
    # 1.0 is an arbitrary positive stand-in: its value is discarded by the
    # outer where, but it keeps sqrt's derivative finite on masked entries.
    x_safe = jnp.where(positive, x, 1.0)
    return jnp.where(positive, jnp.sqrt(x_safe), 0.0)
[docs]
def compute_relative_difference(
    a: jnp.ndarray | np.ndarray,
    b: jnp.ndarray | np.ndarray,
) -> jnp.ndarray:
    """Compute element-wise relative difference ``|a - b|`` / max(``|a|``, ``|b|``, 1e-10).

    Handy when comparing correlation matrices or parameter arrays whose
    entries live on very different scales, where plain absolute
    differences would be misleading. The 1e-10 term keeps the denominator
    away from zero when both entries vanish.

    Args:
        a: First array.
        b: Second array.

    Returns:
        Relative difference array; for finite inputs the values lie in [0, 2].
    """
    lhs = jnp.asarray(a)
    rhs = jnp.asarray(b)
    scale = jnp.maximum(jnp.abs(lhs), jnp.abs(rhs))
    scale = jnp.maximum(scale, 1e-10)
    return jnp.abs(lhs - rhs) / scale
[docs]
def symmetrize(matrix: jnp.ndarray | np.ndarray) -> jnp.ndarray:
    """Force a matrix (or a stack of matrices) to be exactly symmetric.

    Computes ``(M + M^T) / 2`` where the transpose swaps only the last two
    axes, so an input of shape ``(..., n, n)`` is symmetrized matrix by
    matrix. A plain ``.T`` reverses *all* axes, which for a stacked input
    mixes entries from different matrices (silently so when the batch size
    happens to equal ``n``).

    Args:
        matrix: Square matrix ``(n, n)`` or a stack ``(..., n, n)``.
            0-D/1-D inputs are returned as ``0.5 * (m + m.T)`` as before
            (transposition is a no-op for them).

    Returns:
        Symmetric array with the same shape as ``matrix``.

    Raises:
        ValueError: If the last two dimensions are not equal (e.g. a
            ``(1, n)`` row, which ``m + m.T`` would otherwise broadcast into
            an unrelated ``(n, n)`` matrix).
    """
    m = jnp.asarray(matrix)
    if m.ndim < 2:
        return 0.5 * (m + m.T)
    if m.shape[-1] != m.shape[-2]:
        raise ValueError(
            f"symmetrize expects square matrices, got shape {m.shape}"
        )
    return 0.5 * (m + jnp.swapaxes(m, -1, -2))
# ---------------------------------------------------------------------------
# Shared integral and rate primitives (used by both NLSQ and CMC paths)
# ---------------------------------------------------------------------------
[docs]
def smooth_abs(x: jnp.ndarray, eps: float = 1e-12) -> jnp.ndarray:
    """Differentiable stand-in for ``|x|``: ``sqrt(x**2 + eps)``.

    The kink of ``jnp.abs`` at 0 has no well-defined gradient, which shows
    up as NaN during NUTS backpropagation on matrix diagonals where the
    integrals are exactly zero. Adding ``eps`` under the square root
    rounds off the kink; the largest deviation from ``|x|`` is ``sqrt(eps)``
    at ``x = 0`` (≈1e-6 for the default), shrinking away from zero.

    Args:
        x: Input array.
        eps: Smoothing parameter; 1e-12 gives ~1e-6 bias on the diagonal.

    Returns:
        Smooth ``|x|``, same shape as ``x``.
    """
    squared = x**2
    return jnp.sqrt(squared + eps)
[docs]
def smooth_clip(
    x: jnp.ndarray,
    low: float,
    high: float,
    sharpness: float = 50.0,
) -> jnp.ndarray:
    """Soft clip to ``[low, high]`` with continuous gradient at the boundaries.

    Acts as (approximately) the identity in the interior and is
    softplus-smoothed near the boundaries. Use this for physical bounds
    (e.g. sample fraction in [0, 1]) where a hard ``jnp.clip`` would zero
    the gradient and stall NUTS leapfrog integration or NLSQ Jacobian
    descent (CLAUDE.md rule #7).

    The boundary smear is an *absolute* offset of ``ln(2) / sharpness`` in
    the units of ``x``; it does not scale with ``high - low``. With the
    default ``sharpness=50``:

    * ``x == low`` maps to about ``low + ln(2)/50 ≈ low + 0.0139``, and
      ``x == high`` to about ``high - 0.0139``, when the other bound is far
      away.
    * For x well inside the interval (many ``1/sharpness`` from either
      bound) the output equals ``x`` to within exponentially small terms.

    The two softplus stages are composed, so the output is always below
    ``high``. However, as ``x -> -inf`` it tends to
    ``low - log1p(exp(-sharpness * (high - low))) / sharpness``, which is
    slightly *below* ``low``. This is negligible when
    ``sharpness * (high - low) >> 1`` (about 4e-24 for ``[0, 1]`` at the
    default). It becomes noticeable for narrow ranges, so choose
    ``sharpness`` several times larger than ``1 / (high - low)``.

    A larger ``sharpness`` gives a tighter approximation but steeper
    gradients near the boundary. A smaller one gives stronger
    regularisation.

    Args:
        x: Input array (any shape).
        low: Lower physical bound.
        high: Upper physical bound.
        sharpness: Softplus sharpness; the boundary offset is
            ``ln(2) / sharpness`` (≈0.014 at the default of 50).

    Returns:
        Smoothly bounded array (see the lower-asymptote caveat above), with
        well-defined gradients everywhere.
    """
    k = sharpness
    # Stage 1 — smooth max(x, low): ≈ identity for x >> low, → low for x << low.
    x_lo = low + jax.nn.softplus(k * (x - low)) / k
    # Stage 2 — smooth min(x_lo, high): ≈ identity for x_lo << high,
    # → high for x_lo >> high.
    return high - jax.nn.softplus(k * (high - x_lo)) / k
[docs]
def trapezoid_cumsum(f: jnp.ndarray, dt: float | jnp.ndarray) -> jnp.ndarray:
    """Running trapezoidal integral with O(dt²) accuracy.

    Produces ``out[0] = 0`` and ``out[k] = dt * Σ_{i<k} (f[i] + f[i+1]) / 2``.
    This mirrors homodyne's ``trapezoid_cumsum`` pattern. The difference
    is that the ``dt`` factor is folded into the returned values here,
    whereas homodyne factors it out into the wavevector prefactor.

    Args:
        f: Samples on a uniform time grid, shape (N,).
        dt: Grid spacing.

    Returns:
        Cumulative integral, shape (N,), starting at 0.
    """
    panel_means = (f[1:] + f[:-1]) / 2.0
    running = jnp.cumsum(panel_means) * dt
    origin = jnp.zeros(1)
    return jnp.concatenate([origin, running])
[docs]
def create_time_integral_matrix(cumsum_values: jnp.ndarray) -> jnp.ndarray:
    """Expand cumulative sums into the N×N integral matrix (NLSQ meshgrid path).

    Entry ``M[i, j] = cumsum[j] - cumsum[i]``, i.e. the signed integral
    from time index ``i`` to time index ``j``. Transport integrals should
    pass the result through ``smooth_abs`` for direction-independent decay.
    Velocity integrals use the signed value directly, since it feeds
    ``cos(q cos(φ) ∫v dt)``.

    Args:
        cumsum_values: Cumulative integral, shape (N,).

    Returns:
        Signed integral matrix, shape (N, N).
    """
    as_row = cumsum_values[None, :]
    as_col = cumsum_values[:, None]
    return as_row - as_col
[docs]
def compute_transport_rate(
    t: jnp.ndarray,
    D0: float | jnp.ndarray,
    alpha: float | jnp.ndarray,
    offset: float | jnp.ndarray,
) -> jnp.ndarray:
    """Transport rate function J(t) = max(D0·t^α + offset, 0).

    Shared by both NLSQ and CMC paths — the rate function is the same,
    only the integral evaluation strategy differs.

    The power term ``t^α`` is set to 0 wherever ``t <= 0``. Positive times
    at or below 1e-10 are evaluated as ``(1e-10)^α``, which keeps
    ``jnp.power`` finite for negative ``alpha``.

    Args:
        t: Time array, shape (N,).
        D0: Transport prefactor (Ų/s^α).
        alpha: Transport exponent (dimensionless).
        offset: Constant rate offset (Ų/s).

    Returns:
        Rate values, shape (N,), floored at 0 with ``jnp.maximum``.
    """
    # t_safe: prevent Inf/NaN from jnp.power (and its alpha-gradient, which
    # involves log(t)) when t=0 with negative alpha.
    # t is a data array (not a parameter), so jnp.where here does not affect
    # the parameter gradient; the outer jnp.where(t > 0) handles t=0 exactly.
    t_safe = jnp.where(t > 1e-10, t, 1e-10)
    t_power = jnp.where(t > 0, jnp.power(t_safe, alpha), 0.0)
    rate = D0 * t_power + offset
    # Physical positivity floor. jnp.maximum's JVP averages the two tangents
    # at the kink (0.5x), matching the FD subgradient of D_offset; rewriting
    # to jnp.where(rate >= 0, rate, 0) routes the full tangent at the
    # boundary (2x) and breaks the FD↔autodiff agreement pinned by
    # test_gradient_finite_difference. Allow-listed in
    # tests/unit/core/test_no_gradient_killing_clip.py.
    # NOTE(review): keep this exact expression — the allow-list test above
    # presumably matches on it; confirm before any refactor.
    return jnp.maximum(rate, 0.0)
[docs]
def compute_velocity_rate(
    t: jnp.ndarray,
    v0: float | jnp.ndarray,
    beta: float | jnp.ndarray,
    v_offset: float | jnp.ndarray,
) -> jnp.ndarray:
    """Velocity rate function v(t) = v0·t^β + v_offset.

    In contrast to the transport rate, no floor at 0 is applied. The
    velocity integral only enters through cos(q·cos(φ)·∫v dt), which is
    bounded regardless of sign.

    The power term ``t^β`` is set to 0 wherever ``t <= 0``. Positive times
    at or below 1e-10 are evaluated as ``(1e-10)^β``.

    Args:
        t: Time array, shape (N,).
        v0: Velocity prefactor (Å/s^β).
        beta: Velocity exponent (dimensionless).
        v_offset: Constant velocity offset (Å/s).

    Returns:
        Velocity values, shape (N,).
    """
    # t is data, not a fitted parameter: the stand-in only keeps jnp.power
    # (and its beta-gradient, which involves log(t)) finite at t <= 0; those
    # entries are then replaced by 0 in the mask below.
    t_clamped = jnp.where(t > 1e-10, t, 1e-10)
    powered = jnp.power(t_clamped, beta)
    t_term = jnp.where(t > 0, powered, 0.0)
    return v0 * t_term + v_offset
[docs]
@jax.jit
def safe_sinc(x: jnp.ndarray) -> jnp.ndarray:
    """Unnormalized sinc, sin(x)/x, evaluated safely at and near x = 0.

    Entries with ``|x| <= 1e-10`` return exactly 1.0, the limit of
    sin(x)/x. The division is only ever performed on a non-zero value, so
    neither the output nor its gradient contains 0/0. The masked entries
    receive a zero gradient.

    Args:
        x: Input array (radians, unnormalized).

    Returns:
        sin(x)/x with sinc(0) = 1.
    """
    arr = jnp.asarray(x)
    regular = jnp.abs(arr) > 1e-10
    divisor = jnp.where(regular, arr, 1.0)
    ratio = jnp.sin(divisor) / divisor
    return jnp.where(regular, ratio, 1.0)