"""
Reduce Object Module
This module implements the top-level ``reduce_object`` routine for the KSPEC
pipeline, orchestrating the reduction of a raw science file to produce im(age),
ex(tracted), and red(uced) science files.
"""
import logging
import shutil
from pathlib import Path
from typing import Dict, Any
import numpy as np
from .preproc.make_im import make_im
from .extract.make_ex import make_ex
from .io.image import ImageFile
from .wavecal.scrunch import scrunch_from_arc_id
from .utils.args import init_args, validate_reduce_object_args
logger = logging.getLogger(__name__)
# --- Public entry point ---
def reduce_object(args: Dict[str, Any]) -> None:
    """
    Reduce a raw science file to produce im, ex, and red science files.

    The steps run in a fixed order: argument initialisation/validation,
    IM creation, (optional) double-pass extraction, EX creation, copy
    EX -> RED, flat-field, scrunch, throughput calibration, optional RWSS
    snapshot, sky subtraction, PIXCAL clean-up, the optional corrections
    (telluric, velocity, PCA sky, flux calibration, transfer function,
    de-wiggle) and finally the provenance/status header stamps.

    Parameters
    ----------
    args : dict
        Dictionary containing reduction arguments:

        Required keys:

        - ``RAW_FILENAME``: Input raw filename (if empty/missing, IM
          creation is skipped with a warning)
        - ``IMAGE_FILENAME``: Output IM filename
        - ``EXTRAC_FILENAME``: Output extracted filename
        - ``OUTPUT_FILENAME``: Output reduced filename
        - ``TLMAP_FILENAME``: Tramline map filename
        - ``WAVEL_FILENAME``: Wavelength calibration (arc RED) filename

        Optional keys:

        - ``FFLAT_FILENAME``: Fiber flat filename
        - ``OUT_DIRNAME``: Output directory
        - ``DPCRREX``: Double pass cosmic ray rejection (bool)
        - ``EXTR_OPERATION``: Extraction method (default ``"SUM"``;
          presumably applied by ``init_args`` — this function itself
          falls back to ``''`` when the key is absent)
        - ``OPTEX_MKRES``: Make residual map for optimal extraction (bool)
        - ``VERBOSE``: Verbosity (bool, default True)
        - ``USE_GENCAL``: Use skyline recalibration (bool)
        - ``TST_SKYCAL``: Test skyline calibration (bool)
        - ``INC_RWSS``: Include Reduced Without Sky Subtraction copy (bool)
        - ``SKYSPRSMP``: Super sky subtraction (bool)
        - ``SKYSUB``: Enable sky subtraction (bool, default True)
        - ``SKYCOMBINE``: Sky fiber combination method
          (``'MEAN'`` | ``'MEDIAN'`` | ``'SIGCLIP'``, default ``'MEAN'``)
        - ``SKYCOMBINE_SIGMA``: Sigma-clipping threshold for ``SIGCLIP``
          (float, default 3.0)
        - ``SKYCOMBINE_ITERS``: Max iterations for ``SIGCLIP``
          (int, default 5)
        - ``SKYSUB_PCA``: Enable PCA sky subtraction (bool)
        - ``CALIBFLUX``: Enable flux calibration (bool)
        - ``TELCOR``: Enable telluric correction (bool)
        - ``VELCOR``: Enable velocity correction (bool)
        - ``TRANSFUNC``: Transfer function correction (bool)
        - ``DEWIGGLE``: De-wiggle (bool)

    Raises
    ------
    ValueError
        If ``EXTRAC_FILENAME`` or ``OUTPUT_FILENAME`` is missing or empty
        (checked after the extraction step has run).
    """
    # --- Initialisation ---
    init_args(args)
    validate_reduce_object_args(args)
    verbose = args.get('VERBOSE', True)

    # --- Create IM frame from raw ---
    raw_fname = args.get('RAW_FILENAME')
    im_fname = args.get('IMAGE_FILENAME')
    if raw_fname:
        make_im(raw_fname, im_fname, **args)
    else:
        logger.warning(
            "RAW_FILENAME not in args; skipping MAKE_IM "
            "(assuming IM file already exists)"
        )

    # --- Double-pass cosmic ray rejection (requires OPTEX + residual map) ---
    dbl_pass_crr_extr = args.get('DPCRREX', False)
    operat = args.get('EXTR_OPERATION', '')
    make_res = args.get('OPTEX_MKRES', False)
    is_optex_based = operat in ('OPTEX', 'SCMOPTEX', 'SMCOPTEX')
    if dbl_pass_crr_extr and is_optex_based and make_res:
        # First pass: extract, then clean the IM frame; the unconditional
        # make_ex() call below performs the second pass.
        logger.info("Performing double-pass cosmic ray rejection extraction")
        make_ex(args)
        _clean_im(args)
    elif dbl_pass_crr_extr:
        logger.warning(
            "DPCRREX requested but OPTEX or OPTEX_MKRES not selected — ignoring"
        )

    # --- Create EX frame from IM ---
    make_ex(args)
    ex_filename = args.get('EXTRAC_FILENAME')
    red_filename = args.get('OUTPUT_FILENAME')
    if not ex_filename or not red_filename:
        raise ValueError("EXTRAC_FILENAME and OUTPUT_FILENAME must be specified.")
    if verbose:
        logger.info("=" * 50)
        logger.info("Reducing object spectra from extraction file")
        logger.info("=" * 50)
        logger.info("Extraction file = %s", ex_filename)

    # --- Skyline recalibration (if requested) ---
    # Operates on the EX file, i.e. before it is copied to the RED file.
    if args.get('USE_GENCAL', False):
        _skylines_recalibration(ex_filename, args)

    # --- Create RED frame by copying EX ---
    # All subsequent steps modify the RED file in place.
    logger.info("Creating RED file %s from %s", red_filename, ex_filename)
    shutil.copyfile(ex_filename, red_filename)

    # --- Skyline calibration test ---
    if args.get('USE_GENCAL', False) and args.get('TST_SKYCAL', False):
        _skycalib_test(red_filename, args)

    # --- Divide by fiber flat-field (if flat file provided) ---
    _flatfield(red_filename, args)

    # --- Scrunch (rebin to linear wavelength grid) ---
    _scrunch(red_filename, args)

    # --- Fiber throughput calibration and sky subtraction ---
    # (not applicable for Nod & Shuffle data)
    is_nod_shuffle = _check_nod_shuffle(red_filename)
    if not is_nod_shuffle:
        _throughput_calibrate(red_filename, args)
    # The RWSS snapshot is taken after throughput calibration and before
    # sky subtraction.
    if args.get('INC_RWSS', False):
        _make_rwss(red_filename)
    if not is_nod_shuffle:
        if args.get('SKYSUB', True):
            _skysub(red_filename, args)
        if args.get('SKYSPRSMP', False):
            _super_skysub(red_filename, ex_filename, args)

    # --- Clean up intermediate PIXCAL HDU ---
    _delete_pixcal(red_filename)

    # --- Telluric correction ---
    if args.get('TELCOR', False):
        _telluric_correct(red_filename, args)

    # --- Velocity correction ---
    if args.get('VELCOR', False):
        _velocity_correct(red_filename, args)

    # --- PCA sky subtraction ---
    if not is_nod_shuffle and args.get('SKYSUB_PCA', False):
        _skysub_pca(red_filename, args)

    # --- Flux calibration ---
    if args.get('CALIBFLUX', False):
        _apply_fluxcal(red_filename, args)

    # --- Transfer function correction ---
    if args.get('TRANSFUNC', False):
        _apply_transfer_function(red_filename, args)

    # --- De-wiggle ---
    if args.get('DEWIGGLE', False):
        _dewiggle(red_filename, args)

    # --- Finalize: write metadata and mark as reduced ---
    _write_reduction_args(red_filename, args)
    _set_reduced_status(red_filename)
    _stamp_pipeline_version(red_filename)
    logger.info("Object frame reduced")
    if verbose:
        logger.info("Reduction file %s created.", red_filename)
# ---------------------------------------------------------------------
# Flux Calibration
# ---------------------------------------------------------------------
def _apply_fluxcal(red_filename: str, args: Dict[str, Any]) -> None:
    """Apply spectrophotometric flux calibration to a reduced frame.

    Identifies standard-star fibers (TYPE='C'), matches to BOSZ templates,
    derives per-star calibration vectors, combines them, and applies the
    result to all fibers. Writes back to the RED file in place.

    The function returns early, leaving the file's data untouched, when no
    standard-star fibers are found, when ``CALIBFLUX_CATALOG`` is unset,
    when the catalog is empty, or when every per-star calibration fails.

    Parameters
    ----------
    red_filename : str
        Path to the reduced FITS file (modified in place).
    args : dict
        Reduction arguments. Relevant keys:

        - ``CALIBFLUX_CATALOG`` : str — path to standard-star CSV catalog
        - ``CALIBFLUX_FWHM`` : float — instrument FWHM in Å (default: from
          header SPECFWHM, then 3.0 if the header value is also missing)
        - ``CALIBFLUX_METRIC`` : str — scoring metric (default: ``"chi2"``)
        - ``CALIBFLUX_SMOOTH`` : bool — smooth combined vector (default: False)
    """
    # Imports are local to this function rather than at module level.
    from .constants import FIBER_TYPE_CALIBRATION
    from .io.image import ImageFile
    from .fluxcal.calibration import (
        compute_calibration_vector_for_star,
        combine_calibration_vectors,
        apply_flux_calibration,
    )
    from .fluxcal.photometry import (
        load_filter_curves,
        load_standard_star_catalog,
        photometry_from_catalog_row,
        DEFAULT_BANDS,
    )
    from .fluxcal.templates import TemplateLibrary
    from .fluxcal.masks import load_mask_regions
    from .fluxcal.containers import Spectrum1D
    import numpy as np

    # --- Load the RED file ---
    with ImageFile(red_filename, mode='UPDATE') as red_file:
        spectra = red_file.read_image_data()  # (NFIB, NPIX) or (NPIX, NFIB)
        variance = red_file.read_variance_data()
        fiber_types, nf = red_file.read_fiber_types(1000)
        wave_data = red_file.read_wave_data()  # wavelength solution

        # Determine wavelength axis (1-D common grid for scrunched data)
        if wave_data is not None and wave_data.ndim == 1:
            wavelength = wave_data
        elif wave_data is not None and wave_data.ndim == 2:
            wavelength = wave_data[0]  # use first fiber's wavelength
        else:
            nx, _ = red_file.get_size()
            wavelength = np.arange(nx, dtype=float)
            logger.warning("No wavelength solution found; using pixel indices")

        # Determine data layout
        # NOTE(review): the test below cannot tell the layouts apart when
        # NFIB == NPIX; in that case the data are treated as (NPIX, NFIB).
        if spectra.shape[0] == len(wavelength):
            # (NPIX, NFIB) — transpose to (NFIB, NPIX) for processing
            spectra = spectra.T
            variance = variance.T
            layout = "npix_nfib"
        else:
            layout = "nfib_npix"
        nfib, npix = spectra.shape

        # --- Identify standard-star fibers ---
        std_indices = [
            i for i in range(min(nfib, len(fiber_types)))
            if fiber_types[i] == FIBER_TYPE_CALIBRATION
        ]
        if not std_indices:
            logger.warning(
                "CALIBFLUX requested but no fibers with TYPE='C' found. "
                "Skipping flux calibration."
            )
            return
        logger.info(
            "Flux calibration: %d standard-star fibers (TYPE='C'): %s",
            len(std_indices), std_indices,
        )

        # --- Load resources ---
        catalog_path = args.get('CALIBFLUX_CATALOG')
        if not catalog_path:
            logger.error(
                "CALIBFLUX_CATALOG not set. "
                "Provide the path to the standard-star photometry CSV."
            )
            return
        catalog = load_standard_star_catalog(catalog_path)
        if len(catalog) == 0:
            logger.error("Standard-star catalog is empty: %s", catalog_path)
            return
        library = TemplateLibrary()
        filter_curves = load_filter_curves(DEFAULT_BANDS)
        mask_regions = load_mask_regions("telluric_default")

        # FWHM precedence: explicit arg, then SPECFWHM header, then 3.0.
        instrument_fwhm = args.get('CALIBFLUX_FWHM')
        if instrument_fwhm is None:
            instrument_fwhm = red_file.get_header_value('SPECFWHM')
        if instrument_fwhm is None:
            instrument_fwhm = 3.0
            logger.warning("Using default FWHM=%.1f Å", instrument_fwhm)
        else:
            instrument_fwhm = float(instrument_fwhm)
        metric = args.get('CALIBFLUX_METRIC', 'chi2')
        smooth = args.get('CALIBFLUX_SMOOTH', False)

        # --- Compute per-star calibration vectors ---
        cal_vectors = []
        fiber_table = red_file.read_fiber_table() if red_file.has_fiber_table() else None
        for idx, fib_idx in enumerate(std_indices):
            # Extract observed spectrum for this fiber
            obs_flux = spectra[fib_idx, :]
            obs_var = variance[fib_idx, :]
            # Mask passed to Spectrum1D: True where flux and variance are
            # finite and the variance is non-negative.
            obs_mask = np.isfinite(obs_flux) & (obs_var >= 0) & np.isfinite(obs_var)
            obs_spec = Spectrum1D(
                wavelength=wavelength.copy(),
                flux=obs_flux.copy(),
                # Non-positive (and NaN) variances are replaced by 0.0.
                variance=np.where(obs_var > 0, obs_var, 0.0),
                mask=obs_mask,
                meta={"fiber_id": fib_idx},
            )

            # Get star name from fiber table
            star_name = ""
            if fiber_table is not None:
                try:
                    star_name = str(fiber_table["NAME"][fib_idx]).strip()
                except (KeyError, IndexError):
                    # No usable NAME entry: keep the empty name.
                    pass

            # Match to catalog row (by index for now; positional matching is TODO)
            # i.e. the k-th standard fiber is paired with the k-th catalog row.
            if idx < len(catalog):
                row = catalog[idx]
            else:
                logger.warning(
                    "More standard fibers (%d) than catalog rows (%d); "
                    "skipping fiber %d",
                    len(std_indices), len(catalog), fib_idx,
                )
                continue
            phot = photometry_from_catalog_row(row)
            try:
                cal_vec = compute_calibration_vector_for_star(
                    obs_spec, phot, library, filter_curves,
                    instrument_fwhm_angstrom=instrument_fwhm,
                    mask_regions=mask_regions,
                    metric=metric,
                    star_name=star_name,
                    fiber_id=fib_idx,
                )
                cal_vectors.append(cal_vec)
            except Exception as exc:
                # Best effort: a failing star is logged and skipped so the
                # remaining standards can still be used.
                logger.warning(
                    "Calibration failed for fiber %d (%s): %s",
                    fib_idx, star_name, exc,
                )
                continue

        if not cal_vectors:
            logger.error(
                "All standard-star calibrations failed. "
                "Skipping flux calibration."
            )
            return

        # --- Combine and apply ---
        result = combine_calibration_vectors(
            cal_vectors, method="weighted_mean", smooth=smooth,
        )
        cal_spectra, cal_variance, header_updates = apply_flux_calibration(
            spectra, variance, result,
        )

        # --- Write back ---
        # Restore the on-disk layout if the data were transposed above.
        if layout == "npix_nfib":
            red_file.write_image_data(cal_spectra.T)
            red_file.write_variance_data(cal_variance.T)
        else:
            red_file.write_image_data(cal_spectra)
            red_file.write_variance_data(cal_variance)

        # Update header
        # "HISTORY" maps to an iterable of lines; every other key maps to a
        # (value, comment) pair.
        for key, val in header_updates.items():
            if key == "HISTORY":
                for h in val:
                    red_file.set_header_value("HISTORY", h)
            else:
                value, comment = val
                red_file.set_header_value(key, value, comment=comment)

    logger.info(
        "Flux calibration complete: %d standards, RMS=%.4f",
        result.summary["n_stars_used"],
        result.summary["rms_scatter"],
    )
# =====================================================================
# P0 — Implemented Functions
# =====================================================================
def _scrunch(red_filename: str, args: Dict[str, Any]) -> None:
    """Rebin the object frame onto the arc's wavelength solution.

    The arc RED file is taken from ``args['WAVEL_FILENAME']`` and the work is
    delegated to :func:`wavecal.scrunch.scrunch_from_arc_id`.  When the key
    is missing/empty, or the file does not exist on disk, a warning is logged
    and the frame is left untouched.
    """
    arc_path = args.get('WAVEL_FILENAME')

    if arc_path and Path(arc_path).exists():
        scrunch_from_arc_id(red_filename, arc_path, args, reverse=False)
        logger.info("Scrunched %s using arc %s", red_filename, arc_path)
    elif not arc_path:
        logger.warning("WAVEL_FILENAME not set — skipping scrunch")
    else:
        logger.warning("Arc file %s not found — skipping scrunch", arc_path)
def _check_nod_shuffle(red_filename: str) -> bool:
    """Tell whether the frame was observed in Nod & Shuffle mode.

    Reads the ``UTNODSFL`` header keyword.  A missing keyword means standard
    mode (``False``); otherwise the value is compared, case-insensitively and
    ignoring surrounding whitespace, against ``T``/``TRUE``/``1``/``Y``.
    """
    with ImageFile(red_filename, mode='READ') as header_src:
        raw_flag = header_src.get_header_value('UTNODSFL', None)

    if raw_flag is None:
        return False

    truthy_spellings = ('T', 'TRUE', '1', 'Y')
    normalised = str(raw_flag).strip().upper()
    return normalised in truthy_spellings
def _delete_pixcal(red_filename: str) -> None:
    """Drop the intermediate ``PIXCAL`` HDU from the RED file, if it exists.

    An info message is logged only when ``delete_hdu`` reports a deletion.
    """
    with ImageFile(red_filename, mode='UPDATE') as red_f:
        removed = red_f.delete_hdu('PIXCAL')
        if not removed:
            return
        logger.info("Deleted PIXCAL HDU from %s", red_filename)
def _write_reduction_args(red_filename: str, args: Dict[str, Any]) -> None:
    """Record the reduction arguments in the RED file header.

    Every entry of *args* except the four input/output filename keys is
    written as ``HIERARCH DRARG <KEY> = <value>``.  If the header layer
    rejects a value with ``ValueError``/``TypeError`` the value is written
    again as its ``str()`` representation.
    """
    filename_keys = frozenset((
        'RAW_FILENAME',
        'IMAGE_FILENAME',
        'EXTRAC_FILENAME',
        'OUTPUT_FILENAME',
    ))

    with ImageFile(red_filename, mode='UPDATE') as red_f:
        for arg_name, arg_value in args.items():
            if arg_name not in filename_keys:
                card_key = "HIERARCH DRARG {}".format(arg_name)
                try:
                    red_f.set_header_value(card_key, arg_value)
                except (ValueError, TypeError):
                    # Value type not storable as-is: fall back to its text form.
                    red_f.set_header_value(card_key, str(arg_value))
def _set_reduced_status(red_filename: str) -> None:
    """Flag the output file as reduced (``DRSTATUS = 'REDUCED'``)."""
    status_card = {
        'value': 'REDUCED',
        'comment': 'Reduction status',
    }
    with ImageFile(red_filename, mode='UPDATE') as red_f:
        red_f.set_header_value(
            'DRSTATUS', status_card['value'], comment=status_card['comment']
        )
def _stamp_pipeline_version(red_filename: str) -> None:
    """Record the kspecdr version in the header (``DRPIPVER``) and HISTORY."""
    # Imported lazily from the package, as in the rest of this module's helpers.
    from . import __version__ as pipeline_version

    history_line = f"Reduced with kspecdr {pipeline_version}"
    with ImageFile(red_filename, mode='UPDATE') as red_f:
        red_f.set_header_value(
            'DRPIPVER',
            pipeline_version,
            comment='kspecdr pipeline version',
        )
        red_f.add_history(history_line)
# =====================================================================
# P1 — Implemented Functions
# =====================================================================
def _flatfield(red_filename: str, args: Dict[str, Any]) -> None:
    """Divide the science frame by the fiber flat-field (P1-1: cmfspec_flatfield).

    The flat named by ``FFLAT_FILENAME`` is read, optionally truncated to a
    trusted pixel range, and the science image and variance are divided by it
    in place.  This step runs before scrunching in ``reduce_object``, so the
    flat is expected to be in pixel space.

    Relevant *args* keys:

    - ``USEFFLAT`` (default True): when false, only a HISTORY line is written.
    - ``FFLAT_FILENAME``: flat file; if unset/blank an error is logged and the
      frame is left unchanged.
    - ``TRUNCFLAT`` (default False), ``USEFLATSTART`` (default 1),
      ``USEFLATEND`` (default 2048): 1-based inclusive pixel range inside
      which the flat is trusted; outside it the flat is set to 1 with zero
      variance.

    Raises
    ------
    ValueError
        If the flat and science images have different shapes.
    """
    if not args.get('USEFFLAT', True):
        with ImageFile(red_filename, mode='UPDATE') as sci:
            sci.add_history('Not divided by fibre flat field')
        return

    flat_path = args.get('FFLAT_FILENAME')
    if not (flat_path and str(flat_path).strip()):
        logger.error("FFLAT_FILENAME not set — skipping flat-field division")
        return

    truncate = args.get('TRUNCFLAT', False)
    first_pix = int(args.get('USEFLATSTART', 1))
    last_pix = int(args.get('USEFLATEND', 2048))

    with ImageFile(red_filename, mode='UPDATE') as sci, \
            ImageFile(flat_path, mode='READ') as flat:
        sci_img = sci.read_image_data()  # (NFIB, NPIX)
        sci_var = sci.read_variance_data()
        # Work on copies: the flat may be modified by the truncation below.
        flat_img = flat.read_image_data().copy()
        flat_var = flat.read_variance_data().copy()

        if sci_img.shape != flat_img.shape:
            raise ValueError(
                f"Flat shape {flat_img.shape} != object shape {sci_img.shape}. "
                "Flat must be reduced with the same TLM as the science frame."
            )

        if truncate:
            # Convert the 1-based inclusive range to a 0-based half-open one
            # and neutralise the flat (value 1, variance 0) outside of it.
            lo = first_pix - 1
            hi = last_pix
            untrusted = []
            if lo > 0:
                untrusted.append(slice(None, lo))
            if hi < sci_img.shape[1]:
                untrusted.append(slice(hi, None))
            for columns in untrusted:
                flat_img[:, columns] = 1.0
                flat_var[:, columns] = 0.0

        # Pixels usable for the division: non-zero finite flat, finite object.
        usable = (flat_img != 0.0) & np.isfinite(flat_img) & np.isfinite(sci_img)
        divided = np.where(usable, sci_img / flat_img, np.nan).astype(np.float32)

        # Variance propagation, kept exactly as in the port:
        #   Var_out = (1/flat)^2 * Var_obj + (divided/flat^2)^2 * Var_flat
        # NOTE(review): because ``divided`` is already obj/flat, the second
        # term evaluates to (obj/flat^3)^2 * Var_flat, whereas first-order
        # propagation for a quotient gives (obj/flat^2)^2 * Var_flat.
        # Confirm against 2dfdr CMFSPEC_FLATFIELD before changing.
        usable_var = usable & np.isfinite(sci_var) & np.isfinite(flat_var)
        propagated = np.where(
            usable_var,
            (1.0 / flat_img) ** 2 * sci_var
            + (divided / flat_img ** 2) ** 2 * flat_var,
            np.nan,
        ).astype(np.float32)

        sci.write_image_data(divided)
        sci.write_variance_data(propagated)
        sci.add_history(f'Divided by fibre flat field {flat_path}')
        if truncate:
            sci.add_history(
                f'Flat truncated to pixel range {first_pix}:{last_pix}'
            )

    logger.info("Applied flat-field from %s to %s", flat_path, red_filename)
def _throughput_calibrate(red_filename: str, args: Dict[str, Any]) -> None:
    """Per-fiber throughput correction (P1-2: cmfspec_ftpcal).

    Calculates relative fiber throughputs and divides each fiber spectrum
    (and variance) by its throughput. Writes a ``THPUT`` ImageHDU (1-D,
    length NFIB) to the RED file for downstream diagnostics.

    Methods (``TPMETH`` arg, default ``'OFFSKY'``):

    - ``'OFF'``: fix all fiber throughputs to 1.0 (no correction; THPUT HDU still written)
    - ``'OFFSKY'``: read pre-computed throughputs from ``THPUT_FILENAME``
    - ``'SKYLINE(KGB)'``: Karl Glazebrook's sky-line robust-fit algorithm
    - ``'MEDIAN'``: simple per-fiber mean, normalized by median
      (equivalent to 2dfdr ``UMFSPEC_FTPC``); ``'UMFSPEC'`` is accepted as
      an alias

    Other *args* keys read here:

    - ``THRUPUT`` (default True): when false, only a HISTORY line is written.
    - ``USETHPTFILE`` (default False): when true and a file named
      ``THROUGHPUT.fits`` exists in the current working directory, its first
      HDU whose name contains ``THPUT`` overrides ``TPMETH``.

    An unknown ``TPMETH``, a missing ``THPUT_FILENAME`` for ``OFFSKY``, or a
    file without a ``THPUT`` extension results in a warning, a HISTORY line,
    and no change to the spectra.

    Bad throughputs (NaN / outside 0.01–100) are stored as 0.0 in the
    THPUT extension and their spectra are set to zero (2dfdr convention).
    """
    from astropy.io import fits
    from .utils.fiber import get_override_from_args

    if not args.get('THRUPUT', True):
        with ImageFile(red_filename, mode='UPDATE') as f:
            f.add_history('No throughput calibration performed')
        return

    tpmeth = str(args.get('TPMETH', 'OFFSKY')).upper().strip()

    with ImageFile(red_filename, mode='UPDATE') as red_f:
        spec = red_f.read_image_data()  # (NFIB, NPIX)
        var = red_f.read_variance_data()
        overrides = get_override_from_args(args)
        fiber_types, _ = red_f.read_fiber_types(1000, overrides=overrides)
        nfib = spec.shape[0]
        fiber_types_arr = np.array(
            list(fiber_types[:nfib]), dtype='U1'
        )
        # NaN marks "no throughput known" until a method fills the vector.
        thput_vec = np.full(nfib, np.nan, dtype=np.float64)

        # --- optional external override file ---
        use_external = args.get('USETHPTFILE', False)
        loaded_external = False
        if use_external:
            import os
            if os.path.exists('THROUGHPUT.fits'):
                with fits.open('THROUGHPUT.fits') as hdul:
                    for hdu in hdul:
                        if 'THPUT' in hdu.name.upper() and hdu.data is not None:
                            # Copy as many values as both arrays allow; any
                            # remaining fibers stay NaN.
                            n = min(len(hdu.data), nfib)
                            thput_vec[:n] = hdu.data[:n]
                            loaded_external = True
                            break
            if loaded_external:
                logger.info("Throughput: loaded from THROUGHPUT.fits")
                # Internal marker used only to pick the HISTORY text below.
                tpmeth = '_EXTERNAL'

        # --- compute throughputs if not loaded from external file ---
        if not loaded_external:
            if tpmeth == 'OFFSKY':
                thput_fname = str(args.get('THPUT_FILENAME', '')).strip()
                if not thput_fname:
                    logger.warning(
                        "TPMETH=OFFSKY but THPUT_FILENAME not set — "
                        "skipping throughput calibration"
                    )
                    red_f.add_history('No throughput calibration performed')
                    return
                with fits.open(thput_fname) as hdul:
                    thput_hdu = None
                    for hdu in hdul:
                        if 'THPUT' in hdu.name.upper() and hdu.data is not None:
                            thput_hdu = hdu
                            break
                    if thput_hdu is None:
                        logger.warning(
                            "No THPUT extension found in %s — "
                            "skipping throughput calibration", thput_fname
                        )
                        red_f.add_history('No throughput calibration performed')
                        return
                    n = min(len(thput_hdu.data), nfib)
                    thput_vec[:n] = thput_hdu.data[:n]
                logger.info("Throughput: loaded from %s (OFFSKY)", thput_fname)
            elif tpmeth == 'SKYLINE(KGB)':
                thput_vec = _get_thput_kgb(spec, fiber_types_arr)
                logger.info("Throughput: KGB sky-line algorithm")
            elif tpmeth in ('MEDIAN', 'UMFSPEC'):
                thput_vec = _umfspec_ftpc(spec)
                logger.info("Throughput: per-fiber median (UMFSPEC)")
            elif tpmeth == 'OFF':
                thput_vec[:] = 1.0
                logger.info("Throughput: OFF (all fibers fixed to 1.0)")
            else:
                logger.warning(
                    "Unknown TPMETH '%s' — skipping throughput calibration", tpmeth
                )
                red_f.add_history('No throughput calibration performed')
                return

        # --- sanity check ---
        bad = ~np.isfinite(thput_vec) | (thput_vec < 0.01) | (thput_vec > 100.0)
        thput_vec[bad] = np.nan
        n_bad = int(np.sum(bad))
        if n_bad:
            logger.warning("Throughput: %d fibers have bad/out-of-range values", n_bad)

        # --- divide spectra by throughput ---
        for i in range(nfib):
            if np.isfinite(thput_vec[i]) and thput_vec[i] > 0:
                scale = 1.0 / thput_vec[i]
            else:
                scale = 0.0  # 2dfdr: bad throughput → zero spectrum
            # Non-finite pixels stay NaN; finite ones are scaled (variance
            # by the square of the factor).
            spec[i] = np.where(np.isfinite(spec[i]), spec[i] * scale, np.nan)
            var[i] = np.where(np.isfinite(var[i]), var[i] * scale**2, np.nan)
        red_f.write_image_data(spec.astype(np.float32))
        red_f.write_variance_data(var.astype(np.float32))

        # Write THPUT extension (bad → 0.0 per 2dfdr convention)
        thput_out = np.where(np.isfinite(thput_vec), thput_vec, 0.0).astype(np.float32)
        _write_image_hdu(red_f, 'THPUT', thput_out)

        # HISTORY text per method; unlisted methods get a generic line.
        hist_map = {
            '_EXTERNAL': 'Throughput calibration using THROUGHPUT.fits',
            'OFF': 'Throughput calibration disabled (all fibers = 1.0)',
            'OFFSKY': f'Throughput calibration using OFFSKY from '
                      f'{args.get("THPUT_FILENAME", "")}',
            'MEDIAN': 'Throughput calibration using per-fiber median (UMFSPEC)',
            'UMFSPEC': 'Throughput calibration using per-fiber median (UMFSPEC)',
            'SKYLINE(KGB)': 'Throughput calibration using sky lines (KGB method)',
        }
        red_f.add_history(hist_map.get(tpmeth, f'Throughput calibration ({tpmeth})'))

    logger.info("Throughput calibrated %s (method=%s)", red_filename, tpmeth)
def _make_rwss(red_filename: str) -> None:
    """Snapshot the current spectra and variance into RWSS/RWSSVAR HDUs (P1-4).

    ``reduce_object`` calls this (only when ``INC_RWSS`` is true) after
    throughput calibration and before sky subtraction, so the RED file keeps
    a "Reduced Without Sky Subtraction" copy next to the final data.  Both
    arrays are stored as float32.
    """
    with ImageFile(red_filename, mode='UPDATE') as red_f:
        snapshots = (
            ('RWSS', red_f.read_image_data()),         # (NFIB, NPIX)
            ('RWSSVAR', red_f.read_variance_data()),   # same shape
        )
        for ext_name, payload in snapshots:
            _write_image_hdu(red_f, ext_name, payload.astype('float32'))

    logger.info("Saved RWSS (pre-sky) snapshot and variance in %s", red_filename)
def _skysub(red_filename: str, args: Dict[str, Any]) -> None:
    """Sky subtraction using sky fibers (P1-3: SKYSUB).

    Identifies sky fibers (``TYPE='S'`` in the FIBRES table), rejects those
    with 1/8 or more bad pixels, combines the rest into a master sky
    spectrum, and subtracts it from all fibers.  The combined sky and its
    variance are written to ``SKY`` and ``SKYVAR`` ImageHDUs.

    If the file has no fiber table and no sky fibers were identified, a file
    ``skyfibres.dat`` in the current working directory (one 1-based fiber
    number per line) is used instead.  Entries outside ``1..NFIB`` are
    ignored with a warning.

    Combination method is controlled by ``SKYCOMBINE`` arg:

    - ``'MEAN'`` (default): straight mean of sky fibers.
    - ``'MEDIAN'``: median combination.
    - ``'SIGCLIP'``: iterative sigma-clipping mean.  Clipping threshold and
      maximum iterations are controlled by ``SKYCOMBINE_SIGMA`` (default 3.0)
      and ``SKYCOMBINE_ITERS`` (default 5).

    Variance propagation: ``Var_out = Var_fib + Var_sky``.  ``Var_sky`` is
    computed per pixel from the sky fibers that actually contribute there
    (finite flux), so pixels where some sky fibers are NaN are not assigned
    an artificially small variance.

    Parameters
    ----------
    red_filename : str
        RED file, modified in place.
    args : dict
        Reduction arguments (``SKYSUB``, ``SKYCOMBINE``, ``SKYCOMBINE_SIGMA``,
        ``SKYCOMBINE_ITERS`` and whatever ``get_override_from_args`` reads).
    """
    from .utils.fiber import get_override_from_args

    if not args.get('SKYSUB', True):
        with ImageFile(red_filename, mode='UPDATE') as f:
            f.add_history('No sky subtraction performed')
        return

    combine_method = str(args.get('SKYCOMBINE', 'MEAN')).upper().strip()
    if combine_method not in ('MEAN', 'MEDIAN', 'SIGCLIP'):
        logger.warning(
            "Unknown SKYCOMBINE '%s' — defaulting to MEAN", combine_method
        )
        combine_method = 'MEAN'

    with ImageFile(red_filename, mode='UPDATE') as red_f:
        spec = red_f.read_image_data()  # (NFIB, NPIX)
        var = red_f.read_variance_data()
        overrides = get_override_from_args(args)
        fiber_types, _ = red_f.read_fiber_types(1000, overrides=overrides)
        nfib = spec.shape[0]

        # Identify sky fibers from the FIBRES table
        sky_fibs = [
            i for i in range(min(nfib, len(fiber_types)))
            if fiber_types[i] == 'S'
        ]

        # Fallback: read from skyfibres.dat if no FIBRES table
        if not sky_fibs and not red_f.has_fiber_table():
            dat = Path('skyfibres.dat')
            if dat.exists():
                listed = [
                    int(line.strip()) - 1  # 1-based → 0-based
                    for line in dat.read_text().splitlines()
                    if line.strip().isdigit()
                ]
                # Guard against entries that would wrap around (0 → -1) or
                # index past the last fiber.
                sky_fibs = [i for i in listed if 0 <= i < nfib]
                n_rejected = len(listed) - len(sky_fibs)
                if n_rejected:
                    logger.warning(
                        "Sky fibers: ignored %d out-of-range entries in "
                        "skyfibres.dat (valid range 1..%d)",
                        n_rejected, nfib,
                    )
                logger.info(
                    "Sky fibers: loaded %d from skyfibres.dat", len(sky_fibs)
                )

        if not sky_fibs:
            logger.warning(
                "No sky fibers found — skipping sky subtraction"
            )
            red_f.add_history('No sky subtraction: no sky fibers found')
            return

        # FCHECK: reject fibers with >= 1/8 bad pixels (2dfdr criterion)
        good_sky = [
            i for i in sky_fibs
            if np.sum(~np.isfinite(spec[i])) < spec.shape[1] / 8
        ]
        if not good_sky:
            logger.warning(
                "All %d sky fibers have too many bad pixels — "
                "skipping sky subtraction", len(sky_fibs)
            )
            red_f.add_history(
                'No sky subtraction: all sky fibers have too many bad pixels'
            )
            return
        logger.info(
            "Sky subtraction: %d/%d sky fibers pass quality check (%s)",
            len(good_sky), len(sky_fibs), combine_method,
        )

        # Combine sky fibers into a single sky spectrum and sky variance
        sky_stack = spec[good_sky, :]  # (N_sky, NPIX)
        var_stack = var[good_sky, :]
        n_sky = len(good_sky)

        if combine_method == 'SIGCLIP':
            sigma = float(args.get('SKYCOMBINE_SIGMA', 3.0))
            max_iters = int(args.get('SKYCOMBINE_ITERS', 5))
            sky_spec, sky_var = _sigclip_combine(
                sky_stack, var_stack, sigma=sigma, max_iters=max_iters
            )
            logger.info(
                "Sky sigma-clipping: sigma=%.1f, max_iters=%d", sigma, max_iters
            )
        else:
            # nanmean/nanmedian skip NaN flux values, so the number of
            # fibers behind each combined pixel can be smaller than n_sky.
            # Use that per-pixel count (and only those fibers' variances),
            # matching what _sigclip_combine does.
            contributing = np.isfinite(sky_stack)
            var_used = np.where(contributing, var_stack, np.nan)
            n_eff = contributing.sum(axis=0).clip(min=1).astype(float)
            if combine_method == 'MEDIAN':
                sky_spec = np.nanmedian(sky_stack, axis=0)
                # Variance of median ≈ (π/2) * mean_var / N
                sky_var = (np.pi / 2.0) * np.nanmean(var_used, axis=0) / n_eff
            else:  # MEAN
                sky_spec = np.nanmean(sky_stack, axis=0)
                # Error propagation for mean of N independent measurements
                sky_var = np.nansum(var_used, axis=0) / n_eff ** 2

        # Subtract sky from every fiber and propagate variance; a pixel is
        # bad (NaN in both outputs) if either the fiber or the sky is bad.
        bad = ~np.isfinite(spec) | ~np.isfinite(sky_spec)[np.newaxis, :]
        spec = np.where(bad, np.nan, spec - sky_spec[np.newaxis, :])
        var = np.where(bad, np.nan, var + sky_var[np.newaxis, :])
        red_f.write_image_data(spec.astype(np.float32))
        red_f.write_variance_data(var.astype(np.float32))

        # Write combined sky to SKY extension (bad pixels → 0.0)
        sky_out = np.where(np.isfinite(sky_spec), sky_spec, 0.0).astype(np.float32)
        _write_image_hdu(red_f, 'SKY', sky_out)

        # Write combined sky variance to SKYVAR (bad pixels → NaN)
        sky_var_out = np.where(np.isfinite(sky_spec), sky_var, np.nan).astype(
            np.float32
        )
        _write_image_hdu(red_f, 'SKYVAR', sky_var_out)

        red_f.add_history(
            f'Sky subtracted using {n_sky} sky fibers ({combine_method})'
        )

    logger.info(
        "Sky subtracted %s using %d fibers", red_filename, len(good_sky)
    )
# =====================================================================
# P1 — Private Helpers
# =====================================================================
def _sigclip_combine(
sky_stack: 'np.ndarray',
var_stack: 'np.ndarray',
sigma: float = 3.0,
max_iters: int = 5,
) -> 'tuple[np.ndarray, np.ndarray]':
"""Iterative sigma-clipping mean combination of sky fibers.
For each wavelength pixel, computes the mean and standard deviation of the
contributing sky fibers and masks values that deviate more than ``sigma``
standard deviations from the mean. Iteration continues until no new
pixels are clipped or ``max_iters`` is reached.
Parameters
----------
sky_stack : ndarray, shape (N_sky, NPIX)
Sky fiber spectra.
var_stack : ndarray, shape (N_sky, NPIX)
Corresponding variance arrays.
sigma : float
Clipping threshold in units of standard deviations.
max_iters : int
Maximum number of clipping iterations.
Returns
-------
sky_spec : ndarray, shape (NPIX,)
Sigma-clipped mean sky spectrum.
sky_var : ndarray, shape (NPIX,)
Propagated variance of the combined sky spectrum.
"""
mask = ~np.isfinite(sky_stack) # True = bad/clipped
for _ in range(max_iters):
work = np.where(mask, np.nan, sky_stack)
mean = np.nanmean(work, axis=0) # (NPIX,)
std = np.nanstd(work, axis=0) # (NPIX,)
new_mask = mask | (np.abs(sky_stack - mean[np.newaxis, :]) > sigma * std[np.newaxis, :])
if np.array_equal(new_mask, mask):
break
mask = new_mask
work = np.where(mask, np.nan, sky_stack)
var_work = np.where(mask, np.nan, var_stack)
n_eff = np.sum(~mask, axis=0).clip(min=1).astype(float) # (NPIX,)
sky_spec = np.nanmean(work, axis=0)
sky_var = np.nansum(var_work, axis=0) / n_eff ** 2
return sky_spec, sky_var
def _write_image_hdu(f: ImageFile, name: str, data: 'np.ndarray') -> None:
    """Store *data* in the extension called *name*, creating it if needed.

    The name match is case-insensitive and never considers the primary HDU
    (index 0).  If a matching extension exists its data is replaced;
    otherwise a new ``ImageHDU`` with the upper-cased name is appended to
    ``f.hdul``.
    """
    from astropy.io import fits

    target = name.upper()
    match = next(
        (
            hdu for position, hdu in enumerate(f.hdul)
            if position > 0 and hdu.name.upper() == target
        ),
        None,
    )
    if match is None:
        f.hdul.append(fits.ImageHDU(data=data, name=target))
    else:
        match.data = data
def _umfspec_ftpc(spec: 'np.ndarray') -> 'np.ndarray':
"""Per-fiber mean normalized by the global median (UMFSPEC_FTPC).
This is the 2dfdr ``UMFSPEC_FTPC`` algorithm used as a first-pass
throughput estimate. Values ≤ 0.05 (parked fibers) are set to NaN.
The returned vector is normalized so that the median of good values is 1.
"""
nfib = spec.shape[0]
ftpc = np.full(nfib, np.nan, dtype=np.float64)
for j in range(nfib):
row = spec[j]
good = np.isfinite(row) & (row == row)
if np.sum(good) == 0:
continue
mean_val = float(np.mean(row[good]))
if mean_val <= 0.05: # parked / unused fiber
continue
ftpc[j] = mean_val
valid = np.isfinite(ftpc)
if np.sum(valid) > 0:
med = float(np.median(ftpc[valid]))
if med > 0:
ftpc[valid] /= med
return ftpc
def _subtract_continuum(spec: 'np.ndarray', hw: int = 100) -> 'np.ndarray':
"""Subtract a local median continuum (running window ±hw pixels).
NaN pixels are filled with the global median before filtering, then
restored. Mirrors 2dfdr ``SUBTRACT_MED_FILT`` (box = ±100 pixels).
"""
from scipy.signal import medfilt
nan_mask = ~np.isfinite(spec)
tmp = spec.copy()
if nan_mask.any():
fill = float(np.nanmedian(spec)) if np.any(~nan_mask) else 0.0
tmp[nan_mask] = fill
ksize = 2 * hw + 1
# kernel must be odd and ≤ spectrum length
if ksize > len(tmp):
ksize = len(tmp) if len(tmp) % 2 == 1 else max(1, len(tmp) - 1)
continuum = medfilt(tmp, ksize)
result = spec - continuum
result[nan_mask] = np.nan
return result
def _boxcar1d(spec: 'np.ndarray', width: int = 5) -> 'np.ndarray':
"""Boxcar (top-hat) smoothing of a 1-D spectrum, NaN-aware."""
from scipy.ndimage import uniform_filter1d
nan_mask = ~np.isfinite(spec)
tmp = np.where(nan_mask, 0.0, spec)
cnt = np.where(nan_mask, 0.0, 1.0)
sm_sum = uniform_filter1d(tmp, size=width, mode='nearest') * width
sm_cnt = uniform_filter1d(cnt, size=width, mode='nearest') * width
result = np.where(sm_cnt > 0, sm_sum / sm_cnt, np.nan)
result[nan_mask] = np.nan
return result
def _get_thput_kgb(
    spec: 'np.ndarray', fiber_types: 'np.ndarray'
) -> 'np.ndarray':
    """Sky-line throughput estimate after Karl Glazebrook (SKYLINE(KGB)).

    Steps (mirrors 2dfdr ``GET_THPUT_KGB``):

    1. Take a first-pass throughput from ``_umfspec_ftpc``.
    2. Divide each sky fiber (type ``'S'``) by its first-pass value (1.0 if
       that value is unusable) and take the per-pixel median as the
       reference sky.
    3. Continuum-subtract (±100-px running median) and boxcar-smooth
       (5 px) the reference sky and each fiber.
    4. For every ``'P'``/``'S'`` fiber with at least 20 jointly finite
       pixels, fit fiber vs. sky with ``scipy.stats.siegelslopes``; the slope
       B is the throughput.  If that call fails for any reason, a
       least-squares slope through the origin is used instead.
    5. Keep only 0.01 < B < 100, then divide by the median and re-apply the
       same bounds.

    With no sky fibers at all, the first-pass estimate is returned as is.

    Returns
    -------
    ndarray, shape (NFIB,), float64
        Throughput per fiber; NaN for fibers that were skipped or rejected.
    """
    nfib, _ = spec.shape

    def _type_of(index):
        # Fibers beyond the type table are treated as 'N' (neither P nor S).
        return fiber_types[index] if index < len(fiber_types) else 'N'

    # Step 1: first-pass estimate
    first_pass = _umfspec_ftpc(spec)

    # Step 2: identify sky fibers and build median sky
    sky_idx = [i for i in range(nfib) if _type_of(i) == 'S']
    if not sky_idx:
        logger.warning(
            "KGB throughput: no sky fibers — falling back to MEDIAN"
        )
        return first_pass

    normalised_sky = []
    for s in sky_idx:
        estimate = first_pass[s]
        divisor = float(estimate) if np.isfinite(estimate) and estimate > 0 else 1.0
        normalised_sky.append(spec[s] / divisor)
    sky_med = np.nanmedian(np.array(normalised_sky), axis=0)  # (NPIX,)

    # Step 3: continuum-subtract and smooth the median sky
    sky_ref = _boxcar1d(_subtract_continuum(sky_med, hw=100), width=5)

    # Step 4: fit each P/S fiber
    throughput = np.full(nfib, np.nan, dtype=np.float64)
    for i in range(nfib):
        if _type_of(i) not in ('P', 'S'):
            continue

        fib_ref = _boxcar1d(_subtract_continuum(spec[i].copy(), hw=100), width=5)
        usable = np.isfinite(fib_ref) & np.isfinite(sky_ref)
        if np.sum(usable) < 20:
            logger.debug(
                "KGB: fiber %d has only %d good pixels — skipping", i, int(np.sum(usable))
            )
            continue

        sky_vals = sky_ref[usable]
        fib_vals = fib_ref[usable]
        try:
            from scipy.stats import siegelslopes
            slope = float(siegelslopes(fib_vals, sky_vals).slope)
        except Exception:
            # Fallback: OLS slope (no intercept) through origin
            norm = float(np.dot(sky_vals, sky_vals))
            slope = float(np.dot(sky_vals, fib_vals) / norm) if norm != 0 else 0.0

        if 0.01 < slope < 100.0:
            throughput[i] = slope

    # Step 5: normalize by median
    fitted = np.isfinite(throughput)
    if np.sum(fitted) > 0:
        typical = float(np.median(throughput[fitted]))
        if typical > 0:
            relative = throughput[fitted] / typical
            throughput[fitted] = np.where(
                (relative > 0.01) & (relative < 100.0), relative, np.nan
            )
    return throughput
# =====================================================================
# P2+ — Not Yet Implemented (safe no-ops)
# =====================================================================
def _clean_im(args: Dict[str, Any]) -> None:
    """Placeholder for cleaning the IM frame with the OPTEX residual map.

    Not yet implemented: *args* is ignored and only a warning is logged.
    """
    notice = "Double-pass CR cleaning not yet implemented — skipping"
    logger.warning(notice)
def _skylines_recalibration(filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for sky-emission-line refinement of the wavelength solution.

    Not yet implemented: both arguments are ignored and only a warning is
    logged.
    """
    notice = "Skyline recalibration not yet implemented — skipping"
    logger.warning(notice)
def _skycalib_test(filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for the skyline wavelength-calibration QC test.

    Not yet implemented: both arguments are ignored and only a warning is
    logged.
    """
    notice = "Skyline calibration test not yet implemented — skipping"
    logger.warning(notice)
def _super_skysub(red_filename: str, ex_filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for super-sampled sky subtraction.

    Not yet implemented: all arguments are ignored and only a warning is
    logged.
    """
    notice = "Super sky subtraction not yet implemented — skipping"
    logger.warning(notice)
def _telluric_correct(red_filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for telluric absorption correction.

    Not yet implemented: both arguments are ignored and only a warning is
    logged.
    """
    notice = "Telluric correction not yet implemented — skipping"
    logger.warning(notice)
def _velocity_correct(red_filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for heliocentric/barycentric velocity correction.

    Not yet implemented: both arguments are ignored and only a warning is
    logged.
    """
    notice = "Velocity correction not yet implemented — skipping"
    logger.warning(notice)
def _skysub_pca(red_filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for PCA-based sky subtraction.

    Not yet implemented: both arguments are ignored and only a warning is
    logged.
    """
    notice = "PCA sky subtraction not yet implemented — skipping"
    logger.warning(notice)
def _apply_transfer_function(red_filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for applying an associated transfer function.

    Not yet implemented: both arguments are ignored and only a warning is
    logged.
    """
    notice = "Transfer function correction not yet implemented — skipping"
    logger.warning(notice)
def _dewiggle(red_filename: str, args: Dict[str, Any]) -> None:
    """Placeholder for removing sinusoidal fringing artifacts.

    Not yet implemented: both arguments are ignored and only a warning is
    logged.
    """
    notice = "De-wiggle not yet implemented — skipping"
    logger.warning(notice)