# Source code for kspecdr.extract.make_ex

"""
Extraction Routines for KSPEC.

This module implements the extraction of spectra from image data using tramline maps,
converting the 2dfdr `MAKE_EX` and related subroutines.
"""

import logging
import sys
import numpy as np
from typing import Dict, Any, Optional, Tuple, Union
from pathlib import Path
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
import warnings

from ..io.image import ImageFile
from ..tlm.make_tlm import predict_wavelength
from ..utils.fiber import get_override_from_args

logger = logging.getLogger(__name__)

# Module-level constants
# Fibre-count limit handed to ImageFile.read_fiber_types() when reading the
# fibre table (presumably the maximum number of fibres it will return).
MAX_NFIBRES: int = 1_000
# Marker stored in extracted spectra/variances for elements that could not be
# extracted; named after 2dfdr's VAL__BADR but represented here as NaN.
VAL__BADR: float = np.nan


def make_ex(args: Dict[str, Any]) -> None:
    """
    Top-level entry point for extracting one image (port of 2dfdr MAKE_EX).

    Validates the filenames supplied in ``args`` and then delegates the actual
    work to ``make_ex_from_im``.

    Parameters
    ----------
    args : dict
        Argument dictionary. Keys read here:
        - IMAGE_FILENAME: image to extract (required)
        - EXTRAC_FILENAME: destination for the extracted spectra (required)
        - TLMAP_FILENAME: tramline map; must already exist (required)
        - WTSCHEME: weighting scheme, ``"STND"`` when absent
        The whole dictionary is also forwarded to ``make_ex_from_im``.

    Raises
    ------
    ValueError
        If any of the three filenames is missing or empty.
    FileNotFoundError
        If the tramline map path does not exist.
    """
    im_fname, ex_fname, tlm_fname = (
        args.get(key)
        for key in ("IMAGE_FILENAME", "EXTRAC_FILENAME", "TLMAP_FILENAME")
    )
    if not all((im_fname, ex_fname, tlm_fname)):
        raise ValueError("Missing required filenames (IMAGE, EXTRAC, or TLMAP)")

    # Extraction is impossible without a tramline map, so fail early and clearly.
    if not Path(tlm_fname).exists():
        raise FileNotFoundError(f"Tramline map file not found: {tlm_fname}")

    logger.info(f"Extracting {im_fname} -> {ex_fname} using TLM {tlm_fname}")
    make_ex_from_im(
        im_fname,
        tlm_fname,
        ex_fname,
        args.get("WTSCHEME", "STND"),
        args,
    )
# TODO: Handle stochastic copies (NSTOCHIM) if needed; skipped for now because it is advanced usage.
def _get_hdr_float(img: Any, key: str) -> Optional[float]:
    """Return header ``key`` of ``img`` as a float, or None if absent or non-numeric."""
    value = img.get_header_value(key, None)
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _float_mismatch(
    a: Optional[float], b: Optional[float], tol: float = 1e-3
) -> bool:
    """True when either value is missing or the two differ by more than ``tol``."""
    if a is None or b is None:
        return True
    return abs(a - b) > tol


def _tlm_header_mismatch(im_file: Any, tlm_file: Any) -> bool:
    """
    Decide whether the TLM's wavelength solution disagrees with the image.

    Returns True when the grating identifiers (GRATID, GRATLPMM) are present in
    both headers and differ, or when LAMBDAC/DISPERS are missing from either
    header or differ by more than 1e-3.
    """
    im_gratid = im_file.get_header_value("GRATID", None)
    im_gratl = im_file.get_header_value("GRATLPMM", None)
    tlm_gratid = tlm_file.get_header_value("GRATID", None)
    tlm_gratl = tlm_file.get_header_value("GRATLPMM", None)

    grat_mismatch = (
        im_gratid is not None and tlm_gratid is not None and im_gratid != tlm_gratid
    ) or (
        im_gratl is not None and tlm_gratl is not None and im_gratl != tlm_gratl
    )

    # _float_mismatch treats a missing value as a mismatch, so a TLM without
    # LAMBDAC/DISPERS always triggers a recomputation.
    wave_mismatch = _float_mismatch(
        _get_hdr_float(im_file, "LAMBDAC"), _get_hdr_float(tlm_file, "LAMBDAC")
    ) or _float_mismatch(
        _get_hdr_float(im_file, "DISPERS"), _get_hdr_float(tlm_file, "DISPERS")
    )
    return grat_mismatch or wave_mismatch


def _write_extracted(
    ex_fname: str,
    im_fname: str,
    operat: str,
    ex_img: np.ndarray,
    ex_var: np.ndarray,
    wave_data: Optional[np.ndarray],
    fib_tabl: Any,
) -> None:
    """
    Write the extracted frame to ``ex_fname`` (overwriting it).

    The primary header is copied from the image's primary HDU. Data and
    variance are transposed from (NSPEC, NFIB) to (NFIB, NSPEC) for FITS.
    WAVELA and FIBRES extensions are appended when available.
    """
    from astropy.io import fits

    with fits.open(im_fname) as hdul_src:
        header = hdul_src[0].header.copy()

    header["HISTORY"] = f"Extracted using {operat}"
    header["CTYPE1"] = "Wavelength"
    header["CUNIT1"] = "Angstroms"
    header["CTYPE2"] = "Fibre Number"

    hdul_out = fits.HDUList(
        [
            fits.PrimaryHDU(data=ex_img.T, header=header),
            fits.ImageHDU(data=ex_var.T, name="VARIANCE"),
        ]
    )
    # wave_data is (NFIB, NSPEC), the same layout as the written data.
    if wave_data is not None:
        hdul_out.append(fits.ImageHDU(data=wave_data, name="WAVELA"))
    if fib_tabl is not None:
        hdul_out.append(fits.BinTableHDU(data=fib_tabl, name="FIBRES"))

    hdul_out.writeto(ex_fname, overwrite=True)
    hdul_out.close()


def make_ex_from_im(
    im_fname: str,
    tlm_fname: str,
    ex_fname: str,
    wtscheme: str,
    args: Dict[str, Any],
) -> None:
    """
    Process image file to produce extracted spectra.

    Replaces 2dfdr SUBROUTINE MAKE_EX_FROM_IM.

    Parameters
    ----------
    im_fname : str
        Input image filename
    tlm_fname : str
        Tramline map filename
    ex_fname : str
        Output extracted filename
    wtscheme : str
        Weighting scheme (currently not used by any implemented method)
    args : dict
        Additional arguments. Keys read here: EXTR_OPERATION (default "TRAM"),
        TLM_SHIFT, SCATSUB and SUM_WIDTH. The dict is also passed on to
        ``get_override_from_args`` and ``predict_wavelength``.

    Raises
    ------
    ValueError
        If the extraction method is unknown, or if the TLM's spectral length
        does not match the image.
    NotImplementedError
        For FIT, GAUSS and the optimal-extraction methods
        (OPTEX, CLOPTEX, SMCOPTEX).
    """
    # 1. Get Extraction Method
    operat = args.get("EXTR_OPERATION", "TRAM").upper()
    logger.info(f"Extraction Method: {operat}")

    valid_methods = [
        "FIT",
        "TRAM",
        "NEWTRAM",
        "GAUSS",
        "OPTEX",
        "CLOPTEX",
        "SMCOPTEX",
        "SUM",
    ]
    if operat not in valid_methods:
        raise ValueError(f"Method must be one of {valid_methods}")
    if operat == "FIT":
        raise NotImplementedError("FIT method does NOT work (legacy status)")

    # 2. Open the image file. It stays open until every read that needs it
    # (fibre types, header comparison, wavelength prediction) has finished.
    with ImageFile(im_fname, mode="READ") as im_file:
        im_class = im_file.get_header_value("CLASS", "UNKNOWN")  # not yet used
        instrument_code = im_file.get_instrument_code()  # not yet used
        img_data = im_file.read_image_data()
        var_data = im_file.read_variance_data()
        fib_tabl = im_file.read_fiber_table()

        # 3. Read the tramline map. The TLM is (NFIB, NSPEC), i.e. horizontal.
        with ImageFile(tlm_fname, mode="READ") as tlm_file:
            tlm_data = tlm_file.read_image_data()
            nfib, nspec_tlm = tlm_data.shape

            # Fibre types come from the IM file, as in 2dfdr; the TLM copy
            # might not be up to date.
            overrides = get_override_from_args(args)
            fiber_types, _ = im_file.read_fiber_types(
                MAX_NFIBRES, overrides=overrides
            )

            # Median FWHM from the TLM; reserved for a future TRAM extraction.
            mwidth = float(tlm_file.get_header_value("MWIDTH", 2.0))

            wave_data = tlm_file.read_wave_data()

            # A TLM from a different grating/setup carries the wrong wavelength
            # solution, so rebuild WAVELA from the image header instead.
            if _tlm_header_mismatch(im_file, tlm_file):
                logger.warning(
                    "TLM header mismatch with IM header; recomputing WAVELA from IM."
                )
                wave_data = predict_wavelength(im_file, tlm_data.T, args).T

    # 4. Apply TLM Shift (Shift-Rotate-Tweak). Done out of place so the
    # array returned by the reader is never modified, and an integer-typed
    # TLM is promoted to float instead of raising a casting error.
    tlm_shift = float(args.get("TLM_SHIFT", 0.0))
    if tlm_shift != 0.0:
        logger.info(f"Shifting Tramline Map by {tlm_shift} pixels")
        tlm_data = tlm_data + tlm_shift

    # 5. Background / Scattered Light Subtraction (placeholder for SCATSUB)
    scatsub = args.get("SCATSUB", "NONE")
    if scatsub != "NONE":
        # TODO: implement scattered light subtraction
        logger.warning(
            f"Scattered light subtraction '{scatsub}' requested but not implemented yet."
        )

    # 6. Standardize dimensions
    nspat, nspec = img_data.shape
    if nspec_tlm != nspec:
        # Without this check, a shorter TLM fails with an IndexError deep in
        # the extraction loop, and a longer one writes a WAVELA that does not
        # match the output.
        raise ValueError(
            f"Tramline map spectral length ({nspec_tlm}) does not match "
            f"image spectral length ({nspec})"
        )
    tlm_data_T = tlm_data.T  # (NSPEC, NFIB)

    # Output extracted data: (NSPEC, NFIB)
    ex_img = np.zeros((nspec, nfib), dtype=np.float32)
    ex_var = np.zeros((nspec, nfib), dtype=np.float32)

    # 7. Perform Extraction
    if operat in ("TRAM", "SUM", "NEWTRAM"):
        width = float(args.get("SUM_WIDTH", 10.0))  # Default from SUMEXTR
        # TODO: Implement TRAM (2dfdr UMFIM_TRMEXTR) and NEWTRAM extraction;
        # until then they use the SUMEXTR summing extraction.
        logger.info(f"Performing SUM extraction with width={width}")
        sum_extract(
            nspat, nspec, img_data, var_data, ex_img, ex_var, nfib, tlm_data_T, width
        )
    elif operat == "GAUSS":
        logger.warning(
            "GAUSS extraction not fully implemented. Using simplified fallback or raising error."
        )
        raise NotImplementedError("GAUSS extraction not implemented")
    elif operat in ("OPTEX", "CLOPTEX", "SMCOPTEX"):
        raise NotImplementedError("Optimal extraction not implemented")
    else:
        raise ValueError(f"Unknown operation: {operat}")

    # 8. Post-Processing: zero out 'F' (Guide), 'N' (Not used) and
    # 'U' (Unallocated) fibres, as 2dfdr's UMFIM_ZERO does.
    for fib in range(nfib):
        if fiber_types[fib] in ("F", "N", "U"):
            ex_img[:, fib] = 0.0
            ex_var[:, fib] = 0.0

    # 9. Write Output
    _write_extracted(ex_fname, im_fname, operat, ex_img, ex_var, wave_data, fib_tabl)
    logger.info(f"Written extracted file: {ex_fname}")
def _sum_window(
    column: np.ndarray,
    column_var: np.ndarray,
    centre: float,
    half_width: float,
    nspat: int,
) -> Optional[Tuple[float, float]]:
    """
    Fractionally integrate one spatial column over ``centre ± half_width``.

    Pixel ``p`` is taken to span ``[p, p + 1)``. Each pixel's contribution is
    weighted by its overlap with the window.

    Returns ``(flux, variance)``, or None when the result must be flagged bad.
    That happens when the centre is not finite, when the window lies
    completely off ``[0, nspat - 1]``, or when an overlapping pixel has NaN
    data or variance.
    """
    # A NaN/inf tramline cannot be turned into a pixel index; int(np.floor())
    # would raise and abort the whole extraction.
    if not np.isfinite(centre):
        return None

    tlow = centre - half_width
    thigh = centre + half_width

    # Every pixel touched by [tlow, thigh], clipped to the detector.
    p_start = max(0, int(np.floor(tlow)))
    p_end = min(nspat - 1, int(np.floor(thigh)))
    if p_end < p_start:
        return None  # window is entirely off the detector

    flux = 0.0
    var_sum = 0.0
    for pix in range(p_start, p_end + 1):
        overlap_min = max(float(pix), tlow)
        overlap_max = min(float(pix) + 1.0, thigh)
        if not overlap_max > overlap_min:
            continue  # pixel only touches the window edge
        fraction = overlap_max - overlap_min

        val = column[pix]
        var = column_var[pix]
        if np.isnan(val) or np.isnan(var):
            return None

        flux += val * fraction
        # 2dfdr scales variance linearly with the fraction
        # (TOTVAR = TOTVAR + INVAR*PART). Strict propagation would use
        # fraction**2; kept linear for parity with 2dfdr.
        var_sum += var * fraction
    return flux, var_sum


def sum_extract(
    nspat: int,
    nspec: int,
    indat: np.ndarray,
    invar: np.ndarray,
    outdat: np.ndarray,
    outvar: np.ndarray,
    nfib: int,
    tlmap: np.ndarray,
    width: float,
) -> None:
    """
    Perform simple summing extraction.

    Replaces 2dfdr SUBROUTINE SUMEXTR. For every spectral pixel ``j`` and
    fibre, the spatial window ``[c - width/2, c + width/2]`` around the
    tramline centre ``c = tlmap[j, fibre]`` is integrated. Pixel ``p`` spans
    ``[p, p + 1)`` and edge pixels contribute by their fractional overlap.
    TLM values are used as-is, with no 0/1-based offset applied.

    Parameters
    ----------
    nspat : int
        Number of spatial pixels (rows in indat)
    nspec : int
        Number of spectral pixels (cols in indat)
    indat : np.ndarray
        Input image (NSPAT, NSPEC) - Horizontal Dispersion
    invar : np.ndarray
        Input variance (NSPAT, NSPEC) - Horizontal Dispersion
    outdat : np.ndarray
        Output spectra (NSPEC, NFIB) - Updated in place
    outvar : np.ndarray
        Output variance (NSPEC, NFIB) - Updated in place
    nfib : int
        Number of fibers
    tlmap : np.ndarray
        Tramline map (NSPEC, NFIB)
    width : float
        Width of extraction window

    Notes
    -----
    An output element is set to ``VAL__BADR`` in both ``outdat`` and
    ``outvar`` in three cases:

    - the tramline centre is NaN or infinite;
    - the window lies entirely outside the image;
    - an overlapping pixel has NaN data or variance.
    """
    half_width = width / 2.0
    with logging_redirect_tqdm():
        for j in tqdm(range(nspec), desc="Summing extraction", unit="pixel"):
            # Column j is shared by all fibres; slice it once.
            column = indat[:, j]
            column_var = invar[:, j]
            for fibre in range(nfib):
                result = _sum_window(
                    column, column_var, tlmap[j, fibre], half_width, nspat
                )
                if result is None:
                    outdat[j, fibre] = VAL__BADR
                    outvar[j, fibre] = VAL__BADR
                else:
                    outdat[j, fibre], outvar[j, fibre] = result