Source code for kspecdr.fluxcal.containers

"""
Data containers for kspecdr flux calibration.

All containers are plain dataclasses with light validation in ``__post_init__``.
No heavy dependencies — only numpy.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np


# Public API of this module: the container classes exported by
# ``from kspecdr.fluxcal.containers import *``.
__all__ = [
    "Spectrum1D",
    "Photometry",
    "FilterCurve",
    "StellarTemplate",
    "CalibrationVector",
    "FluxCalibrationResult",
]


@dataclass
class Spectrum1D:
    """One-dimensional, wavelength-calibrated spectrum with variance and mask.

    Parameters
    ----------
    wavelength : ndarray, shape (N,)
        Wavelength axis in Angstrom.
    flux : ndarray, shape (N,)
        Flux values (counts, counts/s, or physical flux density depending
        on pipeline stage).
    variance : ndarray, shape (N,)
        Per-pixel variance, same units as ``flux**2``.
    mask : ndarray of bool, shape (N,)
        True for *good* (usable) pixels.
    meta : dict
        Freeform metadata.  Recognised keys: ``fiber_id``, ``fiber_name``,
        ``exptime``, ``bunit``, ``continuum`` (ndarray) — theoretical
        continuum carried from BOSZ, ``rv_kms`` — applied radial-velocity
        correction.

    Raises
    ------
    ValueError
        If the four arrays do not all have the same length.
    """

    wavelength: np.ndarray
    flux: np.ndarray
    variance: np.ndarray
    mask: np.ndarray
    meta: Dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Coerce the inputs to ndarrays and check that their lengths agree."""
        # The three numeric arrays become float arrays; the mask becomes bool.
        for attr in ("wavelength", "flux", "variance"):
            setattr(self, attr, np.asarray(getattr(self, attr), dtype=float))
        self.mask = np.asarray(self.mask, dtype=bool)

        n_expected = len(self.wavelength)
        lengths_agree = (
            len(self.flux) == len(self.variance) == len(self.mask) == n_expected
        )
        if not lengths_agree:
            raise ValueError(
                "wavelength, flux, variance, and mask must all have the same length"
            )

    @property
    def ivar(self) -> np.ndarray:
        """Inverse variance; zero where the pixel is masked or variance <= 0."""
        usable = (self.variance > 0) & self.mask
        inverse = np.zeros_like(self.variance)
        inverse[usable] = 1.0 / self.variance[usable]
        return inverse

    @property
    def n_pixels(self) -> int:
        """Total number of pixels (length of the wavelength axis)."""
        return len(self.wavelength)

    @property
    def wave_range(self) -> tuple[float, float]:
        """First and last wavelength values, as ``(wave_min, wave_max)`` in Angstrom."""
        first = self.wavelength[0]
        last = self.wavelength[-1]
        return float(first), float(last)

    @property
    def n_good(self) -> int:
        """Number of pixels whose mask entry is True."""
        return int(np.count_nonzero(self.mask))
@dataclass
class Photometry:
    """Broadband photometric measurements for a single standard star.

    All magnitudes are stored in the **AB system** regardless of the native
    photometric system of the catalogue.  The conversion from Vega to AB is
    applied when the object is constructed (see
    :func:`~kspecdr.fluxcal.photometry.photometry_from_catalog_row`).

    Parameters
    ----------
    filter_names : list of str
        Names matching the keys in
        :data:`~kspecdr.fluxcal.photometry.FILTER_INFO`
        (e.g. ``"ps1_g"``, ``"gaia_g"``).
    magnitudes : ndarray, shape (N_bands,)
        AB magnitudes.  ``nan`` marks unavailable bands.
    mag_errors : ndarray, shape (N_bands,)
        Magnitude errors (1-sigma).  ``nan`` or non-positive marks
        unavailable.
    meta : dict
        Optional metadata: ``ra``, ``dec``, ``objid``, ``teff_catalog``,
        ``a_gaia`` (extinction), ``a_g`` (extinction in g-band).

    Raises
    ------
    ValueError
        If ``filter_names``, ``magnitudes`` and ``mag_errors`` do not all
        describe the same number of bands.
    """

    filter_names: List[str]
    magnitudes: np.ndarray
    mag_errors: np.ndarray
    meta: Dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Coerce the numeric inputs to float arrays and validate their sizes."""
        self.magnitudes = np.asarray(self.magnitudes, dtype=float)
        self.mag_errors = np.asarray(self.mag_errors, dtype=float)
        if len(self.filter_names) != len(self.magnitudes):
            raise ValueError("filter_names and magnitudes must have the same length")
        # Without this check a short ``mag_errors`` would make ``valid_bands``
        # silently drop bands (zip truncation) and ``get_band`` fail with an
        # IndexError long after construction.
        if self.mag_errors.shape != self.magnitudes.shape:
            raise ValueError("magnitudes and mag_errors must have the same length")

    def get_band(self, name: str) -> tuple[float, float]:
        """Return ``(mag, mag_err)`` for the named filter.

        Raises ``KeyError`` if *name* is not in this object's filter list.
        """
        try:
            idx = self.filter_names.index(name)
        except ValueError:
            # Present a mapping-style error; the list-search ValueError is
            # an implementation detail callers should not see.
            raise KeyError(name) from None
        return float(self.magnitudes[idx]), float(self.mag_errors[idx])

    def valid_bands(self) -> List[str]:
        """Filter names with finite magnitude *and* positive, finite error."""
        usable = (
            np.isfinite(self.magnitudes)
            & np.isfinite(self.mag_errors)
            & (self.mag_errors > 0.0)
        )
        return [name for name, ok in zip(self.filter_names, usable) if ok]

    def __len__(self) -> int:
        """Number of bands (length of ``filter_names``)."""
        return len(self.filter_names)
@dataclass
class FilterCurve:
    """Transmission curve for one photometric bandpass.

    Parameters
    ----------
    name : str
        Identifier matching a file in ``data/filters/`` (without ``.dat``).
    wavelength : ndarray, shape (M,)
        Wavelength in **Angstrom** (converted from µm on load).
    transmission : ndarray, shape (M,)
        Fractional transmission, 0–1.
    system : str
        Photometric system: ``"AB"`` or ``"Vega"``.
    vega_to_ab : float
        Offset to add to a Vega magnitude to obtain an AB magnitude:
        ``m_AB = m_Vega + vega_to_ab``.  Zero for AB-system filters.

    Raises
    ------
    ValueError
        If ``wavelength`` and ``transmission`` do not have the same shape.
    """

    name: str
    wavelength: np.ndarray
    transmission: np.ndarray
    system: str = "AB"
    vega_to_ab: float = 0.0

    def __post_init__(self) -> None:
        """Coerce both curves to float arrays and check that they line up."""
        self.wavelength = np.asarray(self.wavelength, dtype=float)
        self.transmission = np.asarray(self.transmission, dtype=float)
        # Fail at construction rather than later inside the integration or
        # the boolean indexing performed by the properties below.
        if self.wavelength.shape != self.transmission.shape:
            raise ValueError("wavelength and transmission must have the same length")

    @property
    def wave_eff(self) -> float:
        """Transmission-weighted mean wavelength in Angstrom."""
        # ``np.trapz`` is deprecated since NumPy 2.0 in favour of
        # ``np.trapezoid``; prefer the new name and fall back on older NumPy.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        weighted = integrate(self.wavelength * self.transmission, self.wavelength)
        total = integrate(self.transmission, self.wavelength)
        return float(weighted / total)

    @property
    def wave_range(self) -> tuple[float, float]:
        """Wavelength range where transmission > 1 % of peak (Angstrom).

        Returns the first and last wavelength samples whose transmission
        exceeds the threshold.  Raises ``IndexError`` if no sample does
        (e.g. an all-zero curve).
        """
        thresh = 0.01 * self.transmission.max()
        above = self.wavelength[self.transmission > thresh]
        return float(above[0]), float(above[-1])
@dataclass
class StellarTemplate:
    """A single BOSZ 2024 stellar model spectrum.

    Parameters
    ----------
    wavelength : ndarray, shape (K,)
        Wavelength in Angstrom (native BOSZ log-λ grid at R = 10,000).
    flux : ndarray, shape (K,)
        Surface flux density in erg/s/cm²/Å (= 4π × H, column 3 of file).
    continuum : ndarray, shape (K,)
        Theoretical continuum in the same units (= 4π × C, column 4 of
        file).  Useful as a first-pass normalization without fitting a
        spline.
    teff : float
        Effective temperature in K.
    logg : float
        Surface gravity log(g) in cgs.
    feh : float
        Overall metallicity [M/H].
    alpha_m : float
        Alpha-element enhancement [α/M].
    carbon_m : float
        Carbon enhancement [C/M].
    vmicro : float
        Microturbulent velocity in km/s.
    atmos_model : str
        Atmosphere model code: ``"mp"`` (MARCS plane-parallel) or ``"ap"``
        (ATLAS9 plane-parallel).
    source : str
        Provenance string — typically the original filename.

    Raises
    ------
    ValueError
        If ``wavelength``, ``flux`` and ``continuum`` do not share the same
        shape.
    """

    wavelength: np.ndarray
    flux: np.ndarray
    continuum: np.ndarray
    teff: float
    logg: float
    feh: float
    alpha_m: float = 0.0
    carbon_m: float = 0.0
    vmicro: float = 1.0
    atmos_model: str = ""
    source: str = "BOSZ2024"

    def __post_init__(self) -> None:
        """Coerce the three spectral arrays to float and check their shapes."""
        self.wavelength = np.asarray(self.wavelength, dtype=float)
        self.flux = np.asarray(self.flux, dtype=float)
        self.continuum = np.asarray(self.continuum, dtype=float)
        # A mismatched continuum would otherwise be carried silently into
        # the metadata produced by ``to_spectrum1d``.
        if not (self.wavelength.shape == self.flux.shape == self.continuum.shape):
            raise ValueError(
                "wavelength, flux, and continuum must all have the same length"
            )

    def to_spectrum1d(self) -> Spectrum1D:
        """Wrap as :class:`Spectrum1D` (variance = 0, mask = all good).

        The arrays are copied, so the returned spectrum does not share
        array data with this template.
        """
        return Spectrum1D(
            wavelength=self.wavelength.copy(),
            flux=self.flux.copy(),
            variance=np.zeros_like(self.flux),
            mask=np.ones(len(self.flux), dtype=bool),
            meta={
                "teff": self.teff,
                "logg": self.logg,
                "feh": self.feh,
                "alpha_m": self.alpha_m,
                "source": self.source,
                "continuum": self.continuum.copy(),
            },
        )

    @property
    def label(self) -> str:
        """Short human-readable parameter label."""
        return (
            f"Teff={self.teff:.0f} logg={self.logg:.1f} "
            f"[M/H]={self.feh:+.2f} [α/M]={self.alpha_m:+.2f}"
        )
@dataclass
class CalibrationVector:
    """Wavelength-dependent factor converting observed counts to physical flux.

    ``flux_calibrated[i] = counts[i] * cal_factor[i]``

    Parameters
    ----------
    wavelength : ndarray, shape (N,)
        Wavelength axis in Angstrom.
    cal_factor : ndarray, shape (N,)
        Calibration factor in erg/s/cm²/Å per count (or per count/s).
    cal_variance : ndarray, shape (N,)
        Variance on ``cal_factor``.
    mask : ndarray of bool, shape (N,)
        True for *reliable* (unmasked) pixels.  Telluric bands and other
        problematic regions are set to False.
    meta : dict
        Per-star metadata.  Recognised keys: ``star_name``, ``fiber_id``,
        ``teff``, ``logg``, ``feh``, ``alpha_m``, ``scale_factor`` (float)
        — photometric normalisation, ``band_residuals`` (dict) — synth −
        obs per filter, ``chi2``, ``ndof``, ``rv_kms``.

    Raises
    ------
    ValueError
        If the four arrays do not all have the same shape.
    """

    wavelength: np.ndarray
    cal_factor: np.ndarray
    cal_variance: np.ndarray
    mask: np.ndarray
    meta: Dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Coerce the inputs to ndarrays and check that their shapes agree."""
        self.wavelength = np.asarray(self.wavelength, dtype=float)
        self.cal_factor = np.asarray(self.cal_factor, dtype=float)
        self.cal_variance = np.asarray(self.cal_variance, dtype=float)
        self.mask = np.asarray(self.mask, dtype=bool)
        # Reject a wrong-length mask or variance at construction instead of
        # letting it surface later when the vector is applied.
        shapes_agree = (
            self.wavelength.shape
            == self.cal_factor.shape
            == self.cal_variance.shape
            == self.mask.shape
        )
        if not shapes_agree:
            raise ValueError(
                "wavelength, cal_factor, cal_variance, and mask must all have "
                "the same length"
            )

    @property
    def cal_error(self) -> np.ndarray:
        """1-sigma uncertainty on ``cal_factor``.

        Entries whose variance is not ``>= 0`` (negative or NaN) are
        reported as 0.0 rather than propagating an invalid square root.
        """
        return np.sqrt(np.where(self.cal_variance >= 0, self.cal_variance, 0.0))

    @property
    def n_good(self) -> int:
        """Number of pixels whose mask entry is True."""
        return int(self.mask.sum())
@dataclass
class FluxCalibrationResult:
    """Everything produced by flux-calibrating a single exposure.

    Parameters
    ----------
    combined_vector : CalibrationVector
        Final combined calibration curve to apply to all science fibers.
    per_star_vectors : list of CalibrationVector
        Individual per-star calibration vectors before combination.
    per_star_residuals : list of ndarray
        Fractional residual arrays
        ``(Cal_star − Cal_combined) / Cal_combined`` for each star.
        Shape (N,) each.
    summary : dict
        Exposure-level summary.  Keys: ``n_stars_used`` (int),
        ``n_stars_rejected`` (int), ``rms_scatter`` (float, fractional),
        ``wavelength_range`` (tuple), ``per_star_metrics`` (list of dict).
    """

    combined_vector: CalibrationVector
    per_star_vectors: List[CalibrationVector]
    per_star_residuals: List[np.ndarray]
    summary: Dict = field(default_factory=dict)

    @property
    def n_stars(self) -> int:
        """Number of entries in ``per_star_vectors``."""
        star_count = len(self.per_star_vectors)
        return star_count