Source code for kspecdr.tlm.make_tlm

"""
Tramline map generation using astropy-based I/O.

This module provides functions for generating tramline maps from FITS images,
replacing the Fortran TDFIO functions with astropy-based equivalents.

TODO: check the usage of the arguments in the function calls.
"""

import numpy as np
import sys
import logging
from typing import Tuple, Optional, Dict, Any, Sequence, Union
from scipy import ndimage
from scipy.optimize import linear_sum_assignment
from scipy.optimize import curve_fit
from scipy.signal import find_peaks
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, fcluster

from kspecdr.tracking import multi_target_tracking
from kspecdr.io.image import ImageFile
from .match_fibers import (
    taipan_nominal_fibpos,
    match_fibers_taipan,
    match_fibers_isoplane,
)
from ..constants import (
    INST_GENERIC,
    INST_2DF,
    INST_6DF,
    INST_AAOMEGA_2DF,
    INST_HERMES,
    INST_AAOMEGA_SAMI,
    INST_TAIPAN,
    INST_AAOMEGA_KOALA,
    INST_AAOMEGA_IFU,
    INST_SPECTOR_HECTOR,
    INST_AAOMEGA_HECTOR,
    INST_ISOPLANE,
    MAX__NFIBRES,
)
from ..utils.args import init_args
from ..utils.fiber import get_override_from_args

logger = logging.getLogger(__name__)


def _default_tlm_filename(im_fname: str) -> str:
    """Derive the tramline-map filename that goes with image file ``im_fname``."""
    import os

    directory, basename = os.path.split(im_fname)
    # Only touch the file name itself (never a directory component) and only
    # the last "_im.fits" marker in it.
    stem, marker, tail = basename.rpartition("_im.fits")
    if marker:
        basename = f"{stem}_tlm.fits{tail}"
    else:
        # No "_im.fits" marker: a plain substitution would return the input
        # name unchanged and the image would later be overwritten by the
        # tramline map.  Build a distinct name instead.
        root, ext = os.path.splitext(basename)
        basename = f"{root}_tlm{ext or '.fits'}"
    return os.path.join(directory, basename)


def make_tlm(args: Dict[str, Any]) -> None:
    """
    Generate a tramline map from an image file.

    This function replaces the Fortran MAKE_TLM subroutine.

    Parameters
    ----------
    args : dict
        Dictionary containing method arguments including:
        - 'IMAGE_FILENAME': Input image file path
        - 'TLMAP_FILENAME': Output tramline map file path (optional).  When
          absent, the name is derived from the image name: a trailing
          ``_im.fits`` marker in the file name becomes ``_tlm.fits``;
          otherwise ``_tlm`` is inserted before the extension so the output
          never coincides with the input image.

    Raises
    ------
    ValueError
        If 'IMAGE_FILENAME' is missing or empty.
    """
    args = init_args(args)

    im_fname = args.get("IMAGE_FILENAME")
    if not im_fname:
        raise ValueError("IMAGE_FILENAME is required")

    tlm_fname = args.get("TLMAP_FILENAME") or _default_tlm_filename(im_fname)

    logger.info(f"Generating tramline map from {im_fname}")

    with ImageFile(im_fname) as im_file:
        make_tlm_from_im(im_file, tlm_fname, args)

    logger.info(f"Generated tramline map: {tlm_fname}")
def make_tlm_from_im(im_file: ImageFile, tlm_fname: str, args: Dict[str, Any]) -> None:
    """
    Build a tramline map for an image file that is already open.

    Python counterpart of the Fortran MAKE_TLM_FROM_IM subroutine.  The
    instrument code is queried from the image and the work is handed to
    ``make_tlm_other``.

    Parameters
    ----------
    im_file : ImageFile
        Opened image file handler
    tlm_fname : str
        Output tramline map filename
    args : dict
        Method arguments
    """
    code = im_file.get_instrument_code()
    logger.info(f"Instrument code: {code}")
    make_tlm_other(im_file, tlm_fname, code, args)
def make_tlm_other(
    im_file: ImageFile, tlm_fname: str, instrument_code: int, args: Dict[str, Any]
) -> None:
    """
    Run the full tramline-map pipeline for non-2DF instruments.

    Python counterpart of the Fortran MAKE_TLM_OTHER subroutine.  The stages
    are: read the image, pick instrument parameters, detect traces, match
    them to fibres, fill in unmatched fibres, and write the tramline map
    (plus a wavelength extension unless the instrument is 2DF).

    Parameters
    ----------
    im_file : ImageFile
        Opened image file handler
    tlm_fname : str
        Output tramline map filename
    instrument_code : int
        Instrument code
    args : dict
        Method arguments.  Note: the 'SPECTID' entry is set in place from the
        image header so that the matching stage can read it.
    """
    logger.info("Starting tramline map generation for non-2DF instrument")

    # Step 0: read image data and fibre information (variance is not used here)
    img_data, _var_data, fibre_types = read_instrument_data(
        im_file, instrument_code, args
    )
    n_fibres = len(fibre_types)

    # The matching stage looks SPECTID up in args, so copy it from the header
    args["SPECTID"] = im_file.get_header_value("SPECTID", "RED")

    # Step 1: instrument-specific parameters
    params = set_instrument_specific_params(instrument_code, args)
    order, pk_search_method, do_distortion, sparse_fibs, experimental, qad_pksearch = (
        params
    )
    logger.debug(
        f"order: {order}, pk_search_method: {pk_search_method}, do_distortion: {do_distortion}, sparse_fibs: {sparse_fibs}, experimental: {experimental}, qad_pksearch: {qad_pksearch}"
    )

    # Steps 2-3: fibre types -> trace status, then report the tallies
    fibre_has_trace = convert_fibre_types_to_trace_status(
        instrument_code, fibre_types, n_fibres
    )
    for label, status in (
        ("officially in use", "YES"),
        ("potentially able", "MAYBE"),
        ("officially dead", "NO"),
    ):
        logger.info(f"Fibres {label}: {np.sum(fibre_has_trace == status)}")

    # Step 4: find fibre traces across the image
    # Standardized to Horizontal Dispersion: (Spatial, Spectral) = (rows, cols)
    nspat, nspec = img_data.shape
    logger.info(f"Max number of traces: {n_fibres}")
    logger.info(f"Image dimensions: nspec={nspec}, nspat={nspat}")

    ntraces, traces, _rep_slice, pk_posn = detect_traces(
        img_data,
        nspec,
        nspat,
        n_fibres,
        n_fibres,
        order=order,
        sparse_fibs=sparse_fibs,
        experimental=experimental,
        pk_search_mthd=pk_search_method,
        dodist=do_distortion,
    )
    logger.info(f"Found {ntraces} traces across the image")

    # Step 5: match located traces to fibre index
    match_vector, modelled_fibre_positions = match_traces_to_fibres(
        instrument_code, traces, fibre_types, pk_posn, args
    )

    # Step 6: lay the matched traces out per fibre
    tramline_map = convert_traces_to_tramline_map(traces, match_vector, n_fibres)

    # Step 7: fill in fibres that had no matched trace (done in place)
    if instrument_code == INST_TAIPAN:
        interpolate_tramlines_taipan(
            tramline_map, match_vector, modelled_fibre_positions
        )
    else:
        interpolate_tramlines(
            tramline_map, match_vector, get_fibre_separation(instrument_code)
        )

    # Step 8: write tramline data to the output file
    write_tramline_data(tlm_fname, tramline_map, instrument_code, im_file)

    # Step 9: wavelength data is added for everything except 2DF
    if instrument_code != INST_2DF:
        wavelength_data = predict_wavelength(im_file, tramline_map, args)
        write_wavelength_data(tlm_fname, wavelength_data)

    logger.info("Tramline map generation completed")
def read_instrument_data(
    im_file: ImageFile, instrument_code: int, args: Dict[str, Any]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Read the image, variance and fibre-type data from an image file.

    Parameters
    ----------
    im_file : ImageFile
        Opened image file handler
    instrument_code : int
        Instrument code (not used by this function)
    args : dict
        Method arguments; fibre overrides are extracted from them with
        ``get_override_from_args`` and passed to the fibre-type reader.

    Returns
    -------
    tuple
        (img_data, var_data, fibre_types)
    """
    image = im_file.read_image_data()
    variance = im_file.read_variance_data()

    # The reader also returns a fibre count, which callers do not need here.
    fibre_types, _ = im_file.read_fiber_types(
        MAX__NFIBRES, overrides=get_override_from_args(args)
    )

    return image, variance, fibre_types
def set_instrument_specific_params(
    instrument_code: int, args: Dict[str, Any]
) -> Tuple[int, int, bool, bool, bool, bool]:
    """
    Choose the trace-finding parameters for an instrument.

    Parameters
    ----------
    instrument_code : int
        Instrument code
    args : dict
        Method arguments.  'SPARSE_FIBS', 'TLM_FIT_RES' and 'QAD_PKSEARCH'
        are read, each defaulting to False.

    Returns
    -------
    tuple
        (order, pk_search_method, do_distortion, sparse_fibs, experimental, qad_pksearch)
    """
    sparse_fibs = args.get("SPARSE_FIBS", False)
    experimental = args.get("TLM_FIT_RES", False)
    qad_pksearch = args.get("QAD_PKSEARCH", False)

    # Polynomial order: 4 unless the instrument is listed below
    if instrument_code in (
        INST_6DF,
        INST_TAIPAN,
        INST_AAOMEGA_IFU,
        INST_AAOMEGA_KOALA,
    ):
        order = 2
    elif instrument_code == INST_SPECTOR_HECTOR:
        order = 6
    elif instrument_code == INST_AAOMEGA_HECTOR:
        order = 4
    elif instrument_code == INST_ISOPLANE:
        order = 2
    else:
        order = 4

    # Peak search: 0 = emergence watershed (default), 1 = all local peaks,
    # 2 = wavelet convolution
    if instrument_code in (INST_AAOMEGA_KOALA, INST_AAOMEGA_IFU):
        pk_search_method = 1
    elif instrument_code in (INST_TAIPAN, INST_ISOPLANE):
        pk_search_method = 2
    else:
        pk_search_method = 0

    # An explicit request for the quick-and-dirty search wins over the default
    if qad_pksearch:
        pk_search_method = 1
        logger.info("OVERRIDE PEAK SEARCH METHOD TO QAD")

    # Distortion modelling is switched off only for the two Hector arms
    do_distortion = instrument_code not in (INST_SPECTOR_HECTOR, INST_AAOMEGA_HECTOR)

    return (
        order,
        pk_search_method,
        do_distortion,
        sparse_fibs,
        experimental,
        qad_pksearch,
    )
def convert_fibre_types_to_trace_status(
    instrument_code: int, fibre_types: np.ndarray, nf: int
) -> np.ndarray:
    """
    Translate fibre type codes into an expected-trace status per fibre.

    Parameters
    ----------
    instrument_code : int
        Instrument code (not used by this function)
    fibre_types : np.ndarray
        Array of fibre types
    nf : int
        Number of fibres; only the first ``nf`` entries are examined

    Returns
    -------
    np.ndarray
        Array of trace status ('YES', 'NO', 'MAYBE')
    """
    # Anything not listed here (fiducial 'F', dead 'D', unknown codes) is 'NO'.
    status_by_type = {
        "P": "YES",  # Program
        "S": "YES",  # Sky
        "N": "MAYBE",  # Not used
        "U": "MAYBE",  # Unused
    }

    fibre_has_trace = np.full(nf, "NO", dtype="U5")
    for fibno in range(nf):
        fibre_has_trace[fibno] = status_by_type.get(fibre_types[fibno], "NO")

    return fibre_has_trace
def _keep_strongest_peaks(
    peaks: np.ndarray, heights: np.ndarray, max_peaks: int
) -> np.ndarray:
    """Cap ``peaks`` at the ``max_peaks`` highest ones, in ascending position."""
    if len(peaks) <= max_peaks:
        return peaks
    strongest = np.argsort(heights)[::-1][:max_peaks]
    # Selecting by height scrambles the spatial order; restore it so that
    # peak lists stay position-ordered like the untruncated ones.
    return np.sort(peaks[strongest])


def _peaks_in_slice(
    col_data: np.ndarray, pk_search_mthd: int, max_ntraces: int
) -> np.ndarray:
    """Locate fibre peaks in one averaged spatial slice with the requested method."""
    if pk_search_mthd == 1:
        # Quick and dirty: all local maxima, minus those below 10% of the
        # strongest one.
        peaks, _ = find_peaks(col_data, distance=2)
        if len(peaks) > 0:
            heights = col_data[peaks]
            peaks = peaks[heights >= 0.1 * np.max(heights)]
        return _keep_strongest_peaks(peaks, col_data[peaks], max_ntraces)

    if pk_search_mthd == 2:
        # Wavelet convolution method
        peaks = _wavelet_peak_detection(col_data, max_peaks=max_ntraces)
        if len(peaks) > 0:
            return peaks.astype(int)
        return np.array([], dtype=int)

    # Method 0, and any unrecognised method: scipy peak finding with an
    # adaptive height threshold of 10% of the slice maximum.
    max_val = np.nanmax(col_data)
    if not max_val > 0:
        return np.array([], dtype=int)
    peaks, properties = find_peaks(col_data, height=0.1 * max_val, distance=3)
    return _keep_strongest_peaks(peaks, properties["peak_heights"], max_ntraces)


def detect_traces(
    img_data: np.ndarray,
    nspec: int,
    nspat: int,
    max_ntraces: int,
    nf: int,
    order: int = 4,
    sparse_fibs: bool = False,
    experimental: bool = False,
    pk_search_mthd: int = 0,
    dodist: bool = True,
) -> Tuple[int, np.ndarray, np.ndarray, np.ndarray]:
    """
    Detect fiber traces across an image.

    This function examines IMG_DATA(NSPAT, NSPEC) for identifiable fibre
    traces and creates a traces pathlist array. It also returns a
    representative spatial profile slice and its peak list for use in later
    analysis (e.g. matching peaks to fibres).

    This function replaces the Fortran LOCATE_TRACES call.

    Parameters
    ----------
    img_data : np.ndarray
        Image data of shape (nspat, nspec) - Horizontal Dispersion
    nspec, nspat : int
        Dimensions of the image (spectral, spatial)
    max_ntraces : int
        Maximum number of traces to return
    nf : int
        Number of fibers in instrument (currently unused)
    order : int, optional
        Order of polynomial fitting (default: 4). Only values above 4 select
        a fit of that order; otherwise a quadratic is fitted.
    sparse_fibs : bool, optional
        If there is only a sparse number of fibers (currently unused)
    experimental : bool, optional
        If to use experimental restrictions for blurred data (currently unused)
    pk_search_mthd : int, optional
        Peak search method: 0=standard, 1=local peaks, 2=wavelet (default: 0).
        Any other value falls back to the standard method.
    dodist : bool, optional
        Whether to do distortion modeling (currently unused)

    Returns
    -------
    ntraces : int
        Number of trace columns of ``tracea`` that received a fitted path.
    tracea : np.ndarray
        Trace array of shape (nspec, max_ntraces).
    rep_slice : np.ndarray
        Representation profile slice of shape (nspat,).
    rep_pkpos : np.ndarray
        Peak positions found in the representation slice, zero padded to
        shape (nspat,).
    """
    # Heuristic parameters (from Fortran code)
    STEP = 50  # Step size for column sweep
    HWID = 10  # Half width for averaging around columns
    MAXD = 4.0  # Maximum displacement expected for fiber traces

    tracea = np.zeros((nspec, max_ntraces))
    rep_slice = np.zeros(nspat)
    rep_pkpos = np.zeros(nspat)

    nsteps = (nspec - 1) // STEP + 1
    # pk_grid[step, k] holds the k-th peak position of that step; 0 = no peak
    pk_grid = np.zeros((nsteps, max_ntraces))

    # Step 1: Sweep the image to find fiber peaks in selected columns
    logger.info("Sweeping image for signs of fibre traces...")

    for stepno, colno in enumerate(range(0, nspec, STEP)):
        perc = float(colno) / float(nspec) * 100.0
        logger.info(f"Processing column {colno}/{nspec} ({perc:.1f}%)")

        # Average a window of columns around colno along the spectral axis
        # (axis=1), ignoring NaN pixels; rows with no valid pixel stay 0.
        col_start = max(0, colno - HWID)
        col_end = min(nspec, colno + HWID + 1)
        col_range = img_data[:, col_start:col_end]
        ngood = np.sum(~np.isnan(col_range), axis=1)
        valid_rows = ngood > 0

        col_data = np.zeros(nspat)
        if np.any(valid_rows):
            col_data[valid_rows] = (
                np.nansum(col_range[valid_rows, :], axis=1) / ngood[valid_rows]
            )

        peaks = _peaks_in_slice(col_data, pk_search_mthd, max_ntraces)

        npks = len(peaks)
        logger.debug(f"npks: {npks}")

        if npks > 0:
            p_pks = peaks.astype(float)
            pk_grid[stepno, :npks] = p_pks

            # Keep the central slice as the representative profile
            if stepno == nsteps // 2:
                rep_slice = col_data.copy()
                rep_pkpos[:npks] = p_pks

    logger.debug(f"pk_grid shape: {pk_grid.shape}")

    # Step 2: Link peak locations into fiber traces (MTT approach)
    logger.info("Linking trace data to build fiber Tramline Map...")
    ntraces, trace_pts = multi_target_tracking(pk_grid, nsteps, max_ntraces, MAXD)
    logger.info(f"Found {ntraces} traces across the image")

    # Step 3: Interpolate across linked points of each identified trace
    logger.info("Interpolating trace paths...")

    x_fit = np.arange(1, nspec + 1) - 0.5
    x_samples = np.arange(1, nspec + 1, STEP) - 0.5

    # RankWarning lives in np.exceptions from NumPy 2.0 on and directly on
    # the np namespace before that.
    rank_warning = (
        getattr(getattr(np, "exceptions", np), "RankWarning", None) or np.RankWarning
    )

    for idx in range(ntraces):
        valid_mask = trace_pts[idx, :] > 0
        if not np.any(valid_mask):
            continue

        x_valid = x_samples[valid_mask]
        y_valid = trace_pts[idx, valid_mask]

        if len(x_valid) < 3:  # Need at least 3 points for polynomial fitting
            continue

        # Requested order is honoured only above 4; otherwise fit a quadratic
        poly_order = min(order if order > 4 else 2, len(x_valid) - 1)
        try:
            coeffs = np.polyfit(x_valid, y_valid, poly_order)
        except (rank_warning, ValueError):
            # If fitting fails, use linear interpolation
            tracea[:, idx] = np.interp(x_fit, x_valid, y_valid)
        else:
            tracea[:, idx] = np.polyval(coeffs, x_fit)

    # Count the trace columns that actually received a path
    ntraces = int(np.count_nonzero(np.any(tracea != 0, axis=0)))

    logger.info(f"Final number of traces: {ntraces}")

    return ntraces, tracea, rep_slice, rep_pkpos
def _link_peaks_to_traces(
    pk_grid: np.ndarray, nsteps: int, max_ntraces: int, max_displacement: float
) -> Tuple[int, np.ndarray]:
    """
    Link peak locations into fiber traces using a clustering approach.

    Simplified stand-in for the Fortran PK_GRID2TRACES algorithm that uses
    single-linkage hierarchical clustering on (position, step) points instead
    of Multi-Target Tracking.

    Parameters
    ----------
    pk_grid : np.ndarray
        Peak grid array indexed as [step, peak]; entries <= 0 mean "no peak"
    nsteps : int
        Number of steps
    max_ntraces : int
        Maximum number of traces
    max_displacement : float
        Distance threshold used to cut the cluster tree

    Returns
    -------
    tuple
        (ntraces, trace_pts)
        ntraces : int
            Number of traces found
        trace_pts : np.ndarray
            Trace points array of shape (max_ntraces, nsteps), rows ordered
            by ascending median position
    """
    grid = np.asarray(pk_grid)[:nsteps]

    # np.nonzero walks the grid row by row, i.e. step by step
    peak_steps, peak_cols = np.nonzero(grid > 0)
    if peak_steps.size == 0:
        return 0, np.zeros((max_ntraces, nsteps))
    peak_positions = grid[peak_steps, peak_cols]

    if peak_positions.size == 1:
        # pdist/linkage need at least two observations; a lone peak is
        # trivially its own cluster.
        cluster_labels = np.array([1])
    else:
        # Features: [position, step_number]
        features = np.column_stack([peak_positions, peak_steps])
        distances = pdist(features, metric="euclidean")
        # Single linkage chains neighbouring points, which keeps a trace
        # together along its length
        linkage_matrix = linkage(distances, method="single")
        cluster_labels = fcluster(
            linkage_matrix, max_displacement, criterion="distance"
        )

    n_clusters = len(np.unique(cluster_labels))
    logger.debug(f"Number of clusters: {n_clusters}")
    n_clusters = min(n_clusters, max_ntraces)
    logger.debug(f"Number of clusters after min: {n_clusters}")

    # Step 1: drop each peak into the row of its cluster (labels are 1-based).
    # If a cluster has several peaks in one step, the last one wins.
    trace_pts = np.zeros((max_ntraces, nsteps))
    for pos, step, label in zip(peak_positions, peak_steps, cluster_labels):
        if label <= max_ntraces:
            trace_pts[label - 1, step] = pos

    # Step 2: keep traces that cover more than half of the steps
    n_points = np.count_nonzero(trace_pts[:n_clusters] > 0, axis=1)
    significant = np.flatnonzero(n_points > 0.5 * nsteps)
    ntraces = len(significant)
    logger.debug(f"Number of significant traces: {ntraces} over {n_clusters} clusters")

    if ntraces == 0:
        return 0, np.zeros((max_ntraces, nsteps))

    # Step 3: order the surviving traces by median position (ascending)
    medians = [np.median(row[row > 0]) for row in trace_pts[significant]]
    by_position = np.argsort(medians, kind="stable")

    final_trace_pts = np.zeros((max_ntraces, nsteps))
    final_trace_pts[:ntraces] = trace_pts[significant[by_position]]

    return ntraces, final_trace_pts
def match_traces_to_fibres(
    instrument_code: int,
    traces: np.ndarray,
    fibre_types: np.ndarray,
    pk_posn: np.ndarray,
    args: Dict[str, Any],
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Match detected traces to fibre indices.

    Only TAIPAN and ISOPLANE are supported; each is delegated to its own
    matcher from ``match_fibers``.

    Parameters
    ----------
    instrument_code : int
        Instrument code
    traces : np.ndarray
        Detected traces (not used by this function)
    fibre_types : np.ndarray
        Array of fibre types
    pk_posn : np.ndarray
        Peak positions
    args : dict
        Method arguments.  For TAIPAN, 'SPECTID' is read from here
        (default "RED").

    Returns
    -------
    tuple
        (match_vector, modelled_fibre_positions)

    Raises
    ------
    NotImplementedError
        For any instrument other than TAIPAN or ISOPLANE.
    """
    logger.info("Matching traces to fibres")

    n_fibres = len(fibre_types)

    if instrument_code == INST_TAIPAN:
        # Fortran reads SPECTID from the image header; here it is expected in args
        spectid = args.get("SPECTID", "RED")
        nominal_positions = taipan_nominal_fibpos(spectid, n_fibres)
        return match_fibers_taipan(n_fibres, fibre_types, pk_posn, nominal_positions)

    if instrument_code == INST_ISOPLANE:
        # Simple 1-to-1 matching
        return match_fibers_isoplane(n_fibres, pk_posn)

    raise NotImplementedError(
        f"Trace matching for instrument {instrument_code} not implemented"
    )
def convert_traces_to_tramline_map(
    traces: np.ndarray, match_vector: np.ndarray, nf: int
) -> np.ndarray:
    """
    Convert identified traces to fibre tramline map array.

    Parameters
    ----------
    traces : np.ndarray
        Detected traces, one column per trace
    match_vector : np.ndarray
        Vector matching fibre numbers to 1-based trace numbers.  Entries
        <= 0 mean "no trace matched" (the same convention the interpolation
        routines use when they test ``match_vector > 0``).
    nf : int
        Number of fibres

    Returns
    -------
    np.ndarray
        Tramline map array of shape (traces.shape[0], nf); columns of
        unmatched fibres are left at zero.

    Raises
    ------
    IndexError
        If ``match_vector`` has fewer than ``nf`` entries, or references a
        trace number beyond the available traces.
    """
    nx, n_traces = traces.shape

    match = np.asarray(match_vector)
    if match.shape[0] < nf:
        raise IndexError(
            f"match_vector has {match.shape[0]} entries but {nf} fibres were requested"
        )
    match = match[:nf]

    tramline_map = np.zeros((nx, nf))

    # Only strictly positive entries are trace numbers.  A negative entry
    # must not be used as an index, or it would silently pick a trace
    # counted from the end of the array.
    matched_fibres = np.flatnonzero(match > 0)
    tramline_map[:, matched_fibres] = traces[:, match[matched_fibres] - 1]

    n_missing = nf - matched_fibres.size
    logger.info(f"Converted traces to tramline map ({n_missing} missing fibres)")

    return tramline_map
def interpolate_tramlines(
    tramline_map: np.ndarray, match_vector: np.ndarray, sep: float
) -> None:
    """
    Fill in the tramlines of fibres that have no matched trace.

    ``tramline_map`` is modified in place.  Fibres outside the matched range
    are offset from the nearest matched fibre by multiples of ``sep``; fibres
    inside it are linearly interpolated (by fibre index) between the nearest
    matched fibres on either side.  Nothing is changed when fewer than two
    fibres are matched.

    Parameters
    ----------
    tramline_map : np.ndarray
        Tramline map array, one column per fibre
    match_vector : np.ndarray
        Vector matching fibre numbers to trace numbers (> 0 means matched)
    sep : float
        Nominal separation between fibres
    """
    logger.info("Interpolating missing fibre traces")

    nf = tramline_map.shape[1]

    anchors = np.flatnonzero(match_vector > 0)
    if anchors.size < 2:
        logger.warning("Too few matched peaks to interpolate with")
        return

    lowest, highest = anchors[0], anchors[-1]

    # Below the first matched fibre: step down by the nominal separation
    for fibno in range(lowest):
        tramline_map[:, fibno] = tramline_map[:, lowest] - (lowest - fibno) * sep

    # Above the last matched fibre: step up by the nominal separation
    for fibno in range(highest + 1, nf):
        tramline_map[:, fibno] = tramline_map[:, highest] + (fibno - highest) * sep

    # In between: blend the two bracketing matched fibres
    for fibno in range(lowest + 1, highest):
        if match_vector[fibno] != 0:
            continue

        # fibno is not itself an anchor, so this is the first anchor above it
        upper_slot = np.searchsorted(anchors, fibno)
        fibno_above = anchors[upper_slot]
        fibno_below = anchors[upper_slot - 1]

        weight = (fibno - fibno_below) / (fibno_above - fibno_below)
        tramline_map[:, fibno] = (1.0 - weight) * tramline_map[
            :, fibno_below
        ] + weight * tramline_map[:, fibno_above]
def interpolate_tramlines_taipan(
    tramline_map: np.ndarray, match_vector: np.ndarray, nominal_positions: np.ndarray
) -> None:
    """
    Fill in the tramlines of unmatched fibres for the TAIPAN instrument.

    Works like ``interpolate_tramlines`` but spaces fibres according to
    ``nominal_positions`` instead of a constant separation.
    ``tramline_map`` is modified in place, and nothing is changed when fewer
    than two fibres are matched.

    Parameters
    ----------
    tramline_map : np.ndarray
        Tramline map array, one column per fibre
    match_vector : np.ndarray
        Vector matching fibre numbers to trace numbers (> 0 means matched)
    nominal_positions : np.ndarray
        Nominal fibre positions
    """
    logger.info("Interpolating missing fibre traces for TAIPAN")

    nf = tramline_map.shape[1]

    anchors = np.flatnonzero(match_vector > 0)
    if anchors.size < 2:
        logger.warning("Too few matched peaks to interpolate with")
        return

    lowest, highest = anchors[0], anchors[-1]

    # Below the first matched fibre: shift by the nominal offset
    for fibno in range(lowest):
        offset = nominal_positions[lowest] - nominal_positions[fibno]
        tramline_map[:, fibno] = tramline_map[:, lowest] - offset

    # Above the last matched fibre: shift by the nominal offset
    for fibno in range(highest + 1, nf):
        offset = nominal_positions[fibno] - nominal_positions[highest]
        tramline_map[:, fibno] = tramline_map[:, highest] + offset

    # In between: blend the bracketing matched fibres, weighted by where the
    # fibre nominally sits between them
    for fibno in range(lowest + 1, highest):
        if match_vector[fibno] != 0:
            continue

        # fibno is not itself an anchor, so this is the first anchor above it
        upper_slot = np.searchsorted(anchors, fibno)
        fibno_above = anchors[upper_slot]
        fibno_below = anchors[upper_slot - 1]

        weight = (nominal_positions[fibno] - nominal_positions[fibno_below]) / (
            nominal_positions[fibno_above] - nominal_positions[fibno_below]
        )
        tramline_map[:, fibno] = (1.0 - weight) * tramline_map[
            :, fibno_below
        ] + weight * tramline_map[:, fibno_above]
def get_fibre_separation(instrument_code: int) -> float:
    """
    Look up the nominal fibre separation for an instrument.

    Parameters
    ----------
    instrument_code : int
        Instrument code

    Returns
    -------
    float
        Nominal fibre separation in pixels.  Every listed instrument, and
        the fallback for unlisted ones, currently uses 4.0.
    """
    fallback = 4.0

    # Per-instrument table (placeholder values; all equal for now)
    separations = dict.fromkeys(
        (INST_2DF, INST_6DF, INST_AAOMEGA_2DF, INST_HERMES, INST_TAIPAN), 4.0
    )

    return separations.get(instrument_code, fallback)
def write_tramline_data(
    tlm_fname: str, tramline_map: np.ndarray, instrument_code: int, im_file: ImageFile
) -> None:
    """
    Write the tramline map to a new FITS file.

    The map becomes the primary HDU (transposed), and an existing file of
    the same name is overwritten.  Some header keywords are copied from the
    image; a keyword missing there is written with a ``None`` value.

    Parameters
    ----------
    tlm_fname : str
        Output filename
    tramline_map : np.ndarray
        Tramline map array
    instrument_code : int
        Instrument code
    im_file : ImageFile
        Image file handler
    """
    logger.info(f"Writing tramline data to {tlm_fname}")

    from astropy.io import fits

    # Transpose to match FITS convention
    primary = fits.PrimaryHDU(tramline_map.T)
    header = primary.header

    header["INSTRUME"] = f"INST_{instrument_code}"
    header["MWIDTH"] = 1.9  # TODO: pass from separate analysis, not hardcoded
    header["PSF_TYPE"] = "GAUSS"  # TODO: pass from separate analysis, not hardcoded

    # Keywords carried over from the image header
    for keyword in ("LAMBDAC", "DISPERS", "GRATID", "GRATLPMM"):
        header[keyword] = im_file.get_header_value(keyword, None)

    output = fits.HDUList([primary])
    output.writeto(tlm_fname, overwrite=True)
    output.close()
def predict_wavelength(
    im_file: ImageFile, tramline_map: np.ndarray, args: Dict[str, Any]
) -> np.ndarray:
    """
    Predict the wavelength of each pixel along each fibre.

    The instrument code is read from ``im_file`` and selects the predictor;
    only TAIPAN and ISOPLANE are supported.

    Parameters
    ----------
    im_file : ImageFile
        Image file handler
    tramline_map : np.ndarray
        Tramline map array; only its shape (nspec, nf) is used
    args : dict
        Method arguments (not used by this function)

    Returns
    -------
    np.ndarray
        Wavelength array

    Raises
    ------
    NotImplementedError
        For instruments without a wavelength predictor.
    """
    logger.info("Predicting wavelength data")

    nspec, nf = tramline_map.shape
    instrument_code = im_file.get_instrument_code()

    if instrument_code == INST_TAIPAN:
        predictor = predict_wavelength_taipan
    elif instrument_code == INST_ISOPLANE:
        predictor = predict_wavelength_from_dispersion
    else:
        raise NotImplementedError(
            f"Wavelength prediction for instrument code {instrument_code} not yet implemented. "
            "This should implement the PREDICT_WAVELEN functionality from the Fortran code."
        )

    return predictor(im_file, nspec, nf)
def predict_wavelength_taipan(im_file: ImageFile, nspec: int, nf: int) -> np.ndarray:
    """
    Predict wavelength for TAIPAN instrument (Fortran WLA_TAIPAN equivalent).

    A linear dispersion relation around the central pixel is built from the
    LAMBDAC and DISPERS header keywords and scaled by 0.1, as the Fortran
    code does (presumably a conversion to nm).  Every fibre gets the same
    wavelength vector.

    Parameters
    ----------
    im_file : ImageFile
        Image file handler
    nspec : int
        Number of pixels in the dispersion direction
    nf : int
        Number of fibres

    Returns
    -------
    np.ndarray
        float32 wavelength array of shape (nspec, nf)
    """
    try:
        lambdac_str, _ = im_file.read_header_keyword("LAMBDAC")
        dispers_str, _ = im_file.read_header_keyword("DISPERS")
        lambdac = float(lambdac_str)
        dispers = float(dispers_str)
    except Exception as e:
        logger.error(f"Error reading LAMBDAC or DISPERS from header: {e}")
        raise

    midpix = 0.5 * nspec

    # Pixel centres in the Fortran 1-based convention: 0.5, 1.5, ...
    pixel_centres = np.arange(1, nspec + 1, dtype=float) - 0.5
    per_pixel = (dispers * (pixel_centres - midpix) + lambdac) * 0.1

    wavelength_data = np.zeros((nspec, nf), dtype=np.float32)
    wavelength_data[:, :] = per_pixel[:, np.newaxis]

    return wavelength_data
def predict_wavelength_from_dispersion(
    im_file: ImageFile, nspec: int, nf: int
) -> np.ndarray:
    """
    Predict wavelength from dispersion and central wavelength in the header.

    The wavelength at a pixel centre is ``LAMBDAC + DISPERS * (x - nspec/2)``
    with pixel centres ``x = 0.5, 1.5, ..., nspec - 0.5``, the same
    convention ``predict_wavelength_taipan`` uses.  Every fibre gets the
    same wavelength vector.  No unit conversion is applied to the header
    values.

    Parameters
    ----------
    im_file : ImageFile
        Image file handler
    nspec : int
        Number of pixels in the dispersion direction
    nf : int
        Number of fibres

    Returns
    -------
    np.ndarray
        Wavelength array of shape (nspec, nf)

    Raises
    ------
    Exception
        Whatever reading or converting the DISPERS / LAMBDAC keywords
        raises; it is logged and re-raised unchanged.
    """
    midpix = 0.5 * nspec
    try:
        dispers_str, _ = im_file.read_header_keyword("DISPERS")
        lambdac_str, _ = im_file.read_header_keyword("LAMBDAC")
        dispers = float(dispers_str)
        lambdac = float(lambdac_str)
    except Exception as e:
        logger.error(f"Error reading DISPERS or LAMBDAC from header: {e}")
        raise

    # Pixel centres are exactly one pixel apart.  linspace(0.5, nspec + 0.5,
    # nspec) would instead space them nspec / (nspec - 1) apart and stretch
    # the dispersion.
    dist_from_midpix = np.arange(nspec) + 0.5 - midpix
    wavevec = lambdac + dispers * dist_from_midpix

    wavelength_data = wavevec.reshape(nspec, 1).repeat(nf, axis=1)

    return wavelength_data
def write_wavelength_data(tlm_fname: str, wavelength_data: np.ndarray) -> None:
    """
    Append the wavelength array to an existing tramline map file.

    The array is written (transposed) as a new image extension named
    "WAVELA" at the end of the file.

    Parameters
    ----------
    tlm_fname : str
        Tramline map filename; the file must already exist
    wavelength_data : np.ndarray
        Wavelength array
    """
    logger.info("Writing wavelength data")

    from astropy.io import fits

    # Transpose to match FITS convention
    wave_hdu = fits.ImageHDU(wavelength_data.T, name="WAVELA")

    # The context manager releases the file handle even if appending or
    # flushing fails.
    with fits.open(tlm_fname, mode="update") as hdul:
        hdul.append(wave_hdu)
        hdul.flush()
def _ricker_convolve(signal: np.ndarray, scale: float) -> np.ndarray:
    """
    Convolve ``signal`` with a Ricker (Mexican hat) kernel of width ``scale``.

    Used as the fallback when PyWavelets is not installed. The kernel is
    written out explicitly, with the same normalisation as the former
    ``scipy.signal.ricker``, because that helper is no longer available in
    recent SciPy releases.

    Parameters
    ----------
    signal : np.ndarray
        Input signal (1D)
    scale : float
        Wavelet width parameter

    Returns
    -------
    np.ndarray
        Convolved signal with the same length as ``signal``
    """
    from scipy import signal as sp_signal

    # Kernel wide enough to capture the negative side lobes.
    half = max(int(np.ceil(5 * scale)), 2)
    t = np.arange(-half, half + 1, dtype=float)
    u2 = (t / scale) ** 2
    amplitude = 2.0 / (np.sqrt(3.0 * scale) * np.pi**0.25)
    kernel = amplitude * (1.0 - u2) * np.exp(-0.5 * u2)
    return sp_signal.convolve(np.asarray(signal, dtype=float), kernel, mode="same")


def _wavelet_convolution(signal: np.ndarray, scale: float) -> np.ndarray:
    """
    Perform wavelet convolution on a signal.

    This function implements a simplified version of the Fortran
    WAVELET_CONVOLUTION using a Mexican hat wavelet.

    Parameters
    ----------
    signal : np.ndarray
        Input signal
    scale : float
        Wavelet scale parameter

    Returns
    -------
    np.ndarray
        Convolved signal
    """
    try:
        import pywt
    except ImportError:
        # Fallback if pywt is not available. A Mexican hat kernel is used
        # (not a Gaussian): the zero-crossing search downstream needs the
        # negative side lobes, which plain smoothing never produces.
        logger.warning("PyWavelets not available, using simple convolution")
        return _ricker_convolve(signal, scale)

    # Continuous wavelet transform at a single scale with the Mexican hat
    # ("mexh") wavelet.
    coef, _ = pywt.cwt(signal, np.array([scale]), "mexh")
    return np.real(coef[0, :])


def _find_resonant_peaks_ztol(signal: np.ndarray, ztol: float) -> np.ndarray:
    """
    Find resonant peaks in signal above zero tolerance.

    This function implements the Fortran WAVELET_FIND_RES_PEAKS_ZTOL algorithm.
    The signal is split into runs of samples ``>= ztol``; the index of the
    (first) maximum of each completed run is reported.

    Notes
    -----
    The first and last samples are never inspected, and a run that is still
    open when the scan reaches the end of the signal is discarded.

    Parameters
    ----------
    signal : np.ndarray
        Input signal
    ztol : float
        Zero tolerance threshold

    Returns
    -------
    np.ndarray
        Indices of resonant peaks
    """
    peaks = []
    n = len(signal)

    in_positive_range = False
    beg_idx = 0

    for i in range(1, n - 1):
        if not in_positive_range:
            # Still below the threshold: keep scanning
            if signal[i] < ztol:
                continue
            # First sample of a new above-threshold run
            in_positive_range = True
            beg_idx = i
        else:
            # Still inside the current run
            if signal[i] >= ztol:
                continue
            # The run ended at the previous sample; record its maximum
            in_positive_range = False
            end_idx = i - 1
            peaks.append(beg_idx + int(np.argmax(signal[beg_idx : end_idx + 1])))

    return np.array(peaks, dtype=int)


def _find_zero_crossings(
    signal: np.ndarray, peaks: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Find left and right zero crossings for each peak.

    This function implements the Fortran WAVELET_FIND_ZERO_CROSSINGS2 algorithm.
    Peaks with a non-positive signal value, or without a negative sample on
    both sides, are skipped.

    Parameters
    ----------
    signal : np.ndarray
        Input signal
    peaks : np.ndarray
        Peak indices

    Returns
    -------
    tuple
        (lhs_zc, rhs_zc) - left and right zero crossing positions
    """
    # Same search as the width-aware variant; only the widths are dropped.
    lhs_zc, rhs_zc, _ = _find_zero_crossings_with_width(signal, peaks)
    return lhs_zc, rhs_zc


def _wavelet_peak_detection_scipy(
    col_data: np.ndarray,
    widths: Optional[Sequence[float]] = None,
    max_peaks: Optional[int] = None,
) -> np.ndarray:
    """
    Detect peaks using scipy's find_peaks_cwt method.

    This is an alternative to the Fortran-based implementation, using scipy's
    optimized wavelet peak detection.

    Parameters
    ----------
    col_data : np.ndarray
        Column data to analyze
    widths : list, optional
        List of widths for wavelet analysis (default: [2, 4, 8])
    max_peaks : int, optional
        Maximum number of peaks to return (default: None, return all)

    Returns
    -------
    np.ndarray
        Peak positions as floats. When truncated to ``max_peaks`` they are
        ordered by decreasing height, not by position.
    """
    from scipy.signal import find_peaks_cwt

    if widths is None:
        widths = [2, 4, 8]

    peaks = find_peaks_cwt(col_data, widths)  # default: ricker wavelet

    # If we have more peaks than max_peaks, select the highest ones
    if max_peaks is not None and len(peaks) > max_peaks:
        peak_heights = col_data[peaks.astype(int)]
        sorted_indices = np.argsort(peak_heights)[::-1]
        peaks = peaks[sorted_indices[:max_peaks]]

    return peaks.astype(float)


def _find_zero_crossings_with_width(
    signal: np.ndarray, peaks: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Find the zero crossings on both sides of each peak and their separation.

    Starting at each peak index the signal is walked outwards until the first
    negative sample is met; the crossing is placed by linear interpolation
    between that sample and its inner neighbour.

    Parameters
    ----------
    signal : np.ndarray
        Input signal
    peaks : np.ndarray
        Peak indices

    Returns
    -------
    tuple
        (lhs_zc, rhs_zc, widths) where ``widths = rhs_zc - lhs_zc``. Peaks
        whose signal value is not positive, or that lack a negative sample on
        either side, are omitted.
    """
    lhs_zc = []
    rhs_zc = []
    widths = []

    for peak_idx in peaks:
        if signal[peak_idx] <= 0.0:
            continue

        # Left crossing: signal[peak_idx] > 0, so j < peak_idx whenever the
        # test succeeds and j + 1 is always a valid index.
        zero_lhs = -1.0
        for j in range(peak_idx, -1, -1):
            if signal[j] < 0.0:
                j0, j1 = j, j + 1
                zero_lhs = j0 - (j1 - j0) / (signal[j1] - signal[j0]) * signal[j0]
                break

        # Right crossing: by the same argument j - 1 >= peak_idx >= 0.
        zero_rhs = -1.0
        for j in range(peak_idx, len(signal)):
            if signal[j] < 0.0:
                j0, j1 = j - 1, j
                zero_rhs = j0 - (j1 - j0) / (signal[j1] - signal[j0]) * signal[j0]
                break

        # Only keep the peak if both zero crossings were found
        if zero_lhs >= 0 and zero_rhs >= 0:
            lhs_zc.append(zero_lhs)
            rhs_zc.append(zero_rhs)
            widths.append(zero_rhs - zero_lhs)

    return np.array(lhs_zc), np.array(rhs_zc), np.array(widths)


def _filter_by_width(
    peaks: np.ndarray,
    widths: np.ndarray,
    min_keep: int,
    rel_tol: float = 0.35,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Keep peaks whose widths are near the median width.

    Parameters
    ----------
    peaks : np.ndarray
        Peak positions
    widths : np.ndarray
        Peak widths aligned with ``peaks``
    min_keep : int
        Minimum number of peaks that must survive. If fewer would remain
        (e.g. late columns), the width filter is not applied and the original
        peaks and widths are returned.
    rel_tol : float
        Relative tolerance to the median width (e.g. 0.35 -> median±35%)

    Returns
    -------
    tuple
        (filtered_peaks, filtered_widths)
    """
    if len(peaks) == 0:
        return peaks, widths

    w_med = float(np.median(widths))
    m = (widths > (1 - rel_tol) * w_med) & (widths < (1 + rel_tol) * w_med)

    # safety check: if too few peaks are left, do not apply width filter
    if np.count_nonzero(m) < min_keep:
        return peaks, widths

    return peaks[m], widths[m]


def _select_regular_run_by_spacing(
    peaks: np.ndarray,
    expected_n: int,
) -> np.ndarray:
    """
    Select expected_n peaks from peaks, where the adjacent intervals are the
    most regular.

    Parameters
    ----------
    peaks : np.ndarray
        Peak positions (1D)
    expected_n : int
        Number of peaks to select

    Returns
    -------
    np.ndarray
        The run of ``expected_n`` consecutive (position-sorted) peaks with the
        smallest relative spread of spacings; all peaks, sorted, if there are
        no more than ``expected_n`` of them.
    """
    if len(peaks) <= expected_n:
        return np.sort(peaks)

    p = np.sort(peaks)
    best_score = np.inf
    best_run = None

    # sliding window of length expected_n
    for i in range(0, len(p) - expected_n + 1):
        run = p[i : i + expected_n]
        d = np.diff(run)
        d_med = np.median(d)
        # score: relative dispersion (smaller is better)
        score = np.std(d / d_med) if d_med > 0 else np.inf
        if score < best_score:
            best_score = score
            best_run = run

    return best_run if best_run is not None else p[:expected_n]


def _select_regular_run_by_spacing_and_height(
    peaks: np.ndarray,
    expected_n: int,
    heights: Optional[np.ndarray] = None,
    *,
    height_weight: float = 0.3,
    min_height_quantile: float = 0.0,
) -> np.ndarray:
    """
    Select expected_n peaks from peaks such that:
    1) Adjacent intervals are as regular as possible (primary objective),
    2) If heights are provided, prefer runs with larger total peak height.

    Parameters
    ----------
    peaks : np.ndarray
        Peak positions (1D).
    expected_n : int
        Number of peaks to select.
    heights : Optional[np.ndarray]
        Peak heights aligned with `peaks`. If None, selection is spacing-only.
    height_weight : float
        Weight of height term relative to regularity term.
        0.0 -> identical to the original spacing-only behavior.
        Typical range: 0.1 ~ 1.0 (tune depending on your data).
    min_height_quantile : float
        Optionally ignore extremely tiny peaks globally by thresholding heights
        (e.g., 0.05 keeps top 95% by height). Set 0.0 to disable. Ignored when
        `heights` is None.

    Returns
    -------
    np.ndarray
        Selected run of length expected_n (sorted by position).

    Raises
    ------
    ValueError
        If `heights` is given but its shape differs from `peaks`.
    """
    if len(peaks) <= expected_n:
        return np.sort(peaks)

    p = np.asarray(peaks)
    if heights is not None:
        h = np.asarray(heights)
        if h.shape != p.shape:
            raise ValueError("heights must have the same shape as peaks")
    else:
        h = None

    # Sort by peak position, keeping heights aligned
    order = np.argsort(p)
    p = p[order]
    if h is not None:
        h = h[order]

    # Optional global height threshold to remove ultra-tiny junk peaks.
    # Without heights there is nothing to threshold.
    if h is not None and min_height_quantile > 0.0:
        thr = np.quantile(h, min_height_quantile)
        keep = h >= thr
        p = p[keep]
        h = h[keep]
        if len(p) <= expected_n:
            return p  # already sorted

    best = None
    best_final = np.inf

    # Normalize heights for comparable scaling across frames
    scale = 1.0
    if h is not None:
        # robust scale: median of the top-k heights (avoids domination by one
        # huge peak)
        k = min(10, len(h))
        scale = np.median(np.sort(h)[-k:]) if k > 0 else 1.0
        if scale <= 0:
            scale = 1.0

    for i in range(0, len(p) - expected_n + 1):
        run = p[i : i + expected_n]
        d = np.diff(run)
        d_med = np.median(d)
        reg = np.std(d / d_med) if d_med > 0 else np.inf

        if h is None:
            final = reg
        else:
            run_h = h[i : i + expected_n]
            # Height term: larger total height -> smaller penalty.
            # Negative log keeps it smooth and not overly sensitive.
            sum_h = float(np.sum(run_h) / scale)
            height_penalty = -np.log(max(sum_h, 1e-12))
            final = reg + height_weight * height_penalty

        if final < best_final:
            best_final = final
            best = run

    return best if best is not None else p[:expected_n]


def _robust_sigma_mad(x: np.ndarray) -> float:
    """
    Robust noise estimate using MAD (NaNs are ignored).

    sigma ~= 1.4826 * MAD
    """
    x = np.asarray(x)
    med = np.nanmedian(x)
    mad = np.nanmedian(np.abs(x - med))
    return 1.4826 * mad


def _wavelet_convolution_multiscale(
    signal: np.ndarray, scales: Sequence[float]
) -> np.ndarray:
    """
    Multi-scale Mexican-hat CWT response.

    Parameters
    ----------
    signal : np.ndarray
        Input signal (1D)
    scales : sequence of float
        Wavelet scales

    Returns
    -------
    np.ndarray
        Per-sample maximum of the response across scales. Negative values are
        retained; the zero-crossing search relies on them.
    """
    signal = np.asarray(signal, dtype=float)
    scales = np.asarray(list(scales), dtype=float)

    try:
        import pywt
    except ImportError:
        # Fallback: explicit Ricker (Mexican hat) convolution at each width
        responses = [_ricker_convolve(signal, s) for s in scales]
        return np.max(np.vstack(responses), axis=0)

    coef, _ = pywt.cwt(signal, scales, "mexh")  # shape: (n_scales, n_samples)
    return np.max(np.real(coef), axis=0)


def _wavelet_peak_detection(
    col_data: np.ndarray,
    scales: Union[float, Sequence[float]] = (1.5, 2.0, 2.5, 3.0),
    k_mad: float = 5.0,
    max_peaks: Optional[int] = None,
    width_rel_tol: float = 0.35,
) -> np.ndarray:
    """
    Detect peaks using multi-scale wavelet response + MAD-based threshold.

    Parameters
    ----------
    col_data : np.ndarray
        Column data to analyze
    scales : float or sequence of float
        Wavelet scales (multi-scale). If float, it will be treated as [scales].
    k_mad : float
        Threshold in units of MAD-sigma. Typical: 3~8 (start with 5).
    max_peaks : int, optional
        Maximum number of peaks to return. If None, every peak that passes the
        width filter is returned.
    width_rel_tol : float
        Relative tolerance around the median peak width used by the width
        filter.

    Returns
    -------
    np.ndarray
        Peak positions (float, sub-pixel via zero crossings)
    """
    col_data = np.asarray(col_data, dtype=float)
    if col_data.size == 0:
        return np.array([], dtype=float)

    if np.ndim(scales) == 0:
        scales = [float(scales)]

    # Step 1: Multi-scale wavelet response (per-pixel max across scales)
    cwt = _wavelet_convolution_multiscale(col_data, scales)

    # Step 2: MAD-based ztol (robust to strong fibers)
    sigma = _robust_sigma_mad(cwt)
    if not np.isfinite(sigma) or sigma <= 0:
        # Very flat response: fall back to 10% of the maximum value
        ztol = 0.1 * np.max(cwt)
    else:
        ztol = k_mad * sigma

    resonant_peaks = _find_resonant_peaks_ztol(cwt, ztol)

    # Step 3: zero crossings around resonant peaks
    lhs_zc, rhs_zc, widths = _find_zero_crossings_with_width(cwt, resonant_peaks)

    # Step 4: midpoint of zero crossings
    if len(lhs_zc) == 0 or len(rhs_zc) == 0:
        return np.array([], dtype=float)

    peak_positions = 0.5 * (lhs_zc + rhs_zc)
    logger.debug(f"# of peaks before width filtering: {len(peak_positions)}")

    # Step 5: filter out peaks with widths too far from the median width.
    # With no target count (max_peaks is None) the only requirement is that
    # the filter must not discard every peak.
    min_keep = max_peaks // 2 if max_peaks is not None else 1
    peak_positions, widths = _filter_by_width(
        peak_positions, widths, min_keep=min_keep, rel_tol=width_rel_tol
    )
    logger.debug(f"# of peaks after width filtering: {len(peak_positions)}")

    # Step 6: select the most regular run of peaks + height preference
    if max_peaks is not None and len(peak_positions) > max_peaks:
        peak_heights = col_data[
            np.clip(peak_positions.astype(int), 0, len(col_data) - 1)
        ]
        peak_positions = _select_regular_run_by_spacing_and_height(
            peak_positions,
            expected_n=max_peaks,
            heights=peak_heights,
            height_weight=0.1,  # tune: 0.1~0.5 typical
            min_height_quantile=0.1,  # optionally 0.05~0.2 to drop tiny junk peaks
        )
    logger.debug(f"# of peaks after regular run selection: {len(peak_positions)}")

    return np.asarray(peak_positions, dtype=float)


def _wavelet_peak_detection_old(
    col_data: np.ndarray, scale: float = 2.0, max_peaks: Optional[int] = None
) -> np.ndarray:
    """
    Detect peaks using wavelet convolution method.

    This function implements the complete wavelet-based peak detection
    algorithm from the Fortran code.

    Parameters
    ----------
    col_data : np.ndarray
        Column data to analyze
    scale : float
        Wavelet scale parameter
    max_peaks : int, optional
        Maximum number of peaks to return (default: None, return all)

    Returns
    -------
    np.ndarray
        Peak positions. When truncated to ``max_peaks`` they are ordered by
        decreasing height, not by position.
    """
    # Step 1: Perform wavelet convolution
    cwt = _wavelet_convolution(col_data, scale)

    # Step 2: Find resonant peaks above zero tolerance
    ztol = 0.1 * np.max(cwt)  # 10% of maximum value
    resonant_peaks = _find_resonant_peaks_ztol(cwt, ztol)

    # Step 3: Find zero crossings for each peak
    lhs_zc, rhs_zc = _find_zero_crossings(cwt, resonant_peaks)

    if len(lhs_zc) == 0 or len(rhs_zc) == 0:
        return np.array([], dtype=float)

    # Step 4: Calculate peak positions as midpoints of zero crossings
    peak_positions = 0.5 * (lhs_zc + rhs_zc)

    # If we have more peaks than max_peaks, select the highest ones
    if max_peaks is not None and len(peak_positions) > max_peaks:
        peak_heights = col_data[peak_positions.astype(int)]
        sorted_indices = np.argsort(peak_heights)[::-1]
        peak_positions = peak_positions[sorted_indices[:max_peaks]]

    return peak_positions.astype(float)