# Source code for heterodyne.cli.main

"""Main entry point for heterodyne CLI."""

from __future__ import annotations

import os
import sys
import time
from typing import TYPE_CHECKING, Literal

if TYPE_CHECKING:
    pass


def _bootstrap_xla_env(argv: list[str] | None) -> None:
    """Inject ``XLA_FLAGS`` from ``--threads``/``--no-jit`` *before* any
    ``heterodyne`` import triggers ``import jax``.

    ``heterodyne/__init__.py`` imports JAX eagerly to set ``jax_enable_x64``.
    JAX reads ``XLA_FLAGS`` only once during backend initialization, so
    configuring after importing ``heterodyne.cli.args_parser`` was too late
    for ``--threads`` to take effect. We pre-parse the affected flags from
    ``argv`` using only the stdlib and seed the env before any package import.
    """
    raw = list(sys.argv[1:] if argv is None else argv)

    os.environ["JAX_PLATFORM_NAME"] = "cpu"
    os.environ["JAX_ENABLE_X64"] = "1"

    threads: int | None = None
    no_jit = False
    i = 0
    while i < len(raw):
        token = raw[i]
        if token == "--threads" and i + 1 < len(raw):
            try:
                threads = int(raw[i + 1])
            except ValueError:
                pass
            i += 2
            continue
        if token.startswith("--threads="):
            try:
                threads = int(token.split("=", 1)[1])
            except ValueError:
                pass
            i += 1
            continue
        if token == "--no-jit":
            no_jit = True
        i += 1

    if threads is not None:
        existing = os.environ.get("XLA_FLAGS", "")
        tflags = (
            "--xla_cpu_multi_thread_eigen=true"
            f" --intra_op_parallelism_threads={threads}"
        )
        if tflags not in existing:
            os.environ["XLA_FLAGS"] = f"{existing} {tflags}".strip()
        # An explicit ``--threads`` must win over any inherited
        # ``OMP_NUM_THREADS``/``MKL_NUM_THREADS`` (e.g. set high by a batch
        # scheduler or login profile). Using ``setdefault`` here would leave
        # BLAS/OpenMP oversubscribed relative to the XLA intra-op limit, so we
        # assign unconditionally — matching the old ``configure_xla()``.
        os.environ["OMP_NUM_THREADS"] = str(threads)
        os.environ["MKL_NUM_THREADS"] = str(threads)

    if no_jit:
        os.environ["JAX_DISABLE_JIT"] = "1"


def main(argv: list[str] | None = None) -> int:
    """Run the heterodyne command-line interface.

    Args:
        argv: Command-line arguments without the program name. ``None``
            means ``sys.argv[1:]``.

    Returns:
        Process exit code. It is ``1`` when argument validation raises
        ``FileNotFoundError`` or the dispatched command raises an exception,
        and ``130`` when the command is interrupted with Ctrl-C
        (``KeyboardInterrupt``). Otherwise it is the value returned by
        ``dispatch_command``.
    """
    # The environment must be seeded before the first ``heterodyne`` import
    # below: any such import runs ``heterodyne/__init__.py``, which imports
    # JAX eagerly, and JAX only reads XLA_FLAGS once.
    _bootstrap_xla_env(argv)

    import logging as _logging

    # Hide JAX backend noise such as GPU-fallback warnings (homodyne parity).
    for noisy_logger in ("jax._src.xla_bridge", "jax._src.compiler"):
        _logging.getLogger(noisy_logger).setLevel(_logging.ERROR)

    from heterodyne.cli.args_parser import create_parser, validate_args

    args = create_parser().parse_args(argv)

    # Non-fatal problems are echoed to stderr; a missing file aborts early.
    try:
        for issue in validate_args(args):
            print(f"Warning: {issue}", file=sys.stderr)
    except FileNotFoundError as err:
        print(f"Error: {err}", file=sys.stderr)
        return 1

    from heterodyne.cli.commands import dispatch_command
    from heterodyne.utils.logging import configure_logging

    # --quiet overrides any -v count; otherwise each -v lowers the threshold.
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
    if args.quiet:
        level = "ERROR"
    else:
        level = (
            "DEBUG"
            if args.verbose >= 2
            else "INFO"
            if args.verbose >= 1
            else "WARNING"
        )
    configure_logging(level=level)

    started = time.perf_counter()
    from heterodyne.utils.logging import get_logger, log_exception

    logger = get_logger(__name__)
    try:
        exit_code = dispatch_command(args)
    except KeyboardInterrupt:
        logger.info("Analysis interrupted by user")
        return 130
    except Exception as exc:
        # Top-level boundary: report the failure and map it to exit code 1.
        log_exception(logger, exc, context={"command": "main"})
        return 1

    duration = time.perf_counter() - started
    if not args.quiet:
        logger.info("Analysis completed in %.1f seconds", duration)
    return exit_code
def main_hexp() -> int:
    """Entry point for the ``hexp`` command, which plots experimental data.

    The process arguments are forwarded to :func:`main` with
    ``--plot-experimental-data`` placed in front of them.

    Returns:
        The exit code produced by :func:`main`.
    """
    forwarded = ["--plot-experimental-data", *sys.argv[1:]]
    return main(forwarded)
def main_hsim() -> int:
    """Entry point for the ``hsim`` command, which plots simulated data.

    The process arguments are forwarded to :func:`main` with
    ``--plot-simulated-data`` placed in front of them.

    Returns:
        The exit code produced by :func:`main`.
    """
    forwarded = ["--plot-simulated-data", *sys.argv[1:]]
    return main(forwarded)
if __name__ == "__main__":
    # Equivalent to sys.exit(main()): propagate main()'s return value as the
    # process exit status.
    raise SystemExit(main())