Source code for heterodyne.io.json_utils
"""JSON serialization utilities for JAX arrays and numpy types."""
from __future__ import annotations

import json
import math
import os
import sys
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, cast

import numpy as np
def _sanitize_float(value: float) -> float:
    """Return *value* unchanged, refusing floats that strict JSON cannot hold.

    JSON has no literal for NaN or +/-Infinity, so such values are rejected
    here with a descriptive message instead of surfacing later as an opaque
    encoder error.

    Args:
        value: Float value to validate.

    Returns:
        The very same ``value`` object, when it is finite.

    Raises:
        ValueError: If ``value`` is NaN, +Inf or -Inf.
    """
    if not math.isfinite(value):
        raise ValueError(f"Cannot serialize non-finite float to JSON: {value!r}")
    return value
[docs]
def _jax_array_type() -> type | None:
    """Return ``jax.Array`` when JAX is already imported, else ``None``.

    Checking ``sys.modules`` instead of running ``import jax`` on every
    (recursive) call has two benefits. It avoids repeating the full import
    search for every element when JAX is not installed, because a failed
    import is never recorded in ``sys.modules``. It also avoids importing
    JAX as a side effect. A ``jax.Array`` instance can only exist once JAX
    has been imported, so no JAX array is missed.
    """
    return getattr(sys.modules.get("jax"), "Array", None)


def _json_safe_key(key: Any) -> Any:
    """Normalize a dict key to a type ``json.dumps`` accepts as an object key.

    numpy scalar keys (e.g. ``np.int64``) become Python scalars and ``Path``
    keys become strings; every other key is returned unchanged.
    """
    if isinstance(key, np.generic):
        return key.item()
    if isinstance(key, Path):
        return str(key)
    return key


def _complex_to_dict(value: Any) -> dict[str, float]:
    """Encode a Python or numpy complex number as ``{"real": ..., "imag": ...}``.

    Both parts go through ``_sanitize_float`` so NaN/Inf are rejected exactly
    as they are for plain floats.
    """
    return {
        "real": _sanitize_float(float(value.real)),
        "imag": _sanitize_float(float(value.imag)),
    }


def json_safe(obj: Any) -> Any:
    """Convert object to JSON-serializable form.

    Handles:
        - JAX arrays -> converted through numpy, then as below
        - numpy arrays -> (nested) lists
        - complex numpy arrays -> ``{"__complex_array__": True,
          "shape": [...], "data": [{"real": ..., "imag": ...}, ...]}``
          with ``data`` in ``ndarray.flat`` (row-major) order, so the
          shape can be restored on load
        - numpy / Python complex scalars -> ``{"real": ..., "imag": ...}``
        - numpy scalars -> Python scalars
        - Path objects -> strings
        - datetime -> ISO format strings
        - dicts, lists, tuples -> converted recursively (tuples become
          lists; numpy-scalar and Path dict keys become Python scalars and
          strings respectively)
        - anything else -> ``str(obj)``, or ``"<non-serializable: TYPE>"``
          if even ``str`` fails

    Args:
        obj: Object to convert.

    Returns:
        JSON-serializable equivalent.

    Raises:
        ValueError: If any float, including either part of a complex number,
            is NaN or infinite.
    """
    # Fast path for the built-in scalars that dominate ``ndarray.tolist()``
    # output. Exact type checks keep subclasses (np.float64, np.str_,
    # IntEnum members, ...) on the general path below, so results match.
    obj_type = type(obj)
    if obj is None or obj_type in (bool, int, str):
        return obj
    if obj_type is float:
        return _sanitize_float(obj)

    jax_array = _jax_array_type()
    if jax_array is not None and isinstance(obj, jax_array):
        return json_safe(np.asarray(obj))

    if isinstance(obj, np.ndarray):
        if np.issubdtype(obj.dtype, np.complexfloating):
            # Preserve the shape explicitly so complex arrays round-trip.
            return {
                "__complex_array__": True,
                "shape": list(obj.shape),
                "data": [_complex_to_dict(z) for z in obj.flat],
            }
        # Recurse so NaN/Inf elements hit _sanitize_float.
        return json_safe(obj.tolist())

    # np.complexfloating scalars and built-in complex share one encoding.
    if isinstance(obj, (complex, np.complexfloating)):
        return _complex_to_dict(obj)

    if isinstance(obj, np.floating):
        return _sanitize_float(float(obj))
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)

    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, datetime):
        return obj.isoformat()

    if isinstance(obj, dict):
        return {_json_safe_key(k): json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [json_safe(item) for item in obj]

    # Subclasses of the JSON primitives (e.g. IntEnum, np.str_) pass through;
    # float subclasses still get the NaN/Inf check.
    if isinstance(obj, (bool, int, str)):
        return obj
    if isinstance(obj, float):
        return _sanitize_float(obj)

    # Fallback: best-effort string conversion.
    try:
        return str(obj)
    except Exception:
        return f"<non-serializable: {type(obj).__name__}>"
[docs]
def json_serializer(obj: Any) -> str:
    """Render *obj* as an indented, strictly standard JSON document.

    The object is first normalized with ``json_safe``; passing
    ``allow_nan=False`` to the encoder then guarantees that no
    ``NaN``/``Infinity`` tokens can appear in the output.

    Args:
        obj: Any value accepted by ``json_safe``.

    Returns:
        The JSON text, indented by two spaces per nesting level.

    Raises:
        ValueError: If a non-finite float is encountered.
    """
    normalized = json_safe(obj)
    return json.dumps(normalized, allow_nan=False, indent=2)
[docs]
def load_json(path: Path | str) -> dict[str, Any]:
    """Read and decode a UTF-8 encoded JSON file.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded document. The result is typed as a dict for callers'
        convenience, but the top-level value is not checked: a file holding
        a JSON array or scalar is returned as whatever ``json`` decodes.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    document = json.loads(raw_text)
    return cast(dict[str, Any], document)
[docs]
def save_json(data: Any, path: Path | str) -> None:
    """Save data to JSON file with pretty formatting.

    The data is serialized (via ``json_serializer``) before the filesystem
    is touched. A payload that cannot be encoded, e.g. one containing NaN,
    therefore raises without creating directories or temporary files.

    The text is then written atomically:

    - It is written to a temporary file in the destination directory.
    - It is flushed and ``fsync``-ed to disk.
    - It is moved over ``path`` with ``os.replace``.

    Without the ``fsync``, a crash shortly after the rename may leave an
    empty or truncated file where the old one used to be.

    NOTE(review): ``tempfile.mkstemp`` creates the file readable/writable by
    the owner only, and ``os.replace`` keeps that mode. The saved file
    therefore does not get the usual umask-derived permissions. Confirm
    this is intended.

    Args:
        data: Data to save; anything accepted by ``json_serializer``.
        path: Output path. Missing parent directories are created.

    Raises:
        ValueError: If ``data`` contains non-finite floats.
        OSError: If creating the directory or temporary file, writing, or
            the final rename fails. The temporary file is removed on any
            failure (best effort).
    """
    output_path = Path(path)
    # Serialize first so invalid data fails before any filesystem change.
    payload = json_serializer(data)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Temp file lives in the target directory so os.replace stays on one
    # filesystem and is therefore an atomic rename.
    fd, tmp_path = tempfile.mkstemp(dir=str(output_path.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(payload)
            # Make sure the bytes are on disk before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, str(output_path))
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise