# Source code for kspecdr.wavecal.landmarks

"""
Landmark registration and signal synchronization.
"""

import numpy as np
import logging
from scipy.interpolate import interp1d
from sklearn.linear_model import RANSACRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LinearRegression

from kspecdr.tracking import multi_target_tracking
from .wavelets import (
    wavelet_convolution,
    find_resonant_peaks2,
)

logger = logging.getLogger(__name__)


def robust_polyfit(x, y, order):
    """
    Robust polynomial fitting using RANSAC.

    A RANSAC regression on polynomial features is used only to classify the
    samples into inliers and outliers. The returned coefficients come from an
    ordinary ``np.polyfit`` over the inliers. The function falls back to a
    plain ``np.polyfit`` over all samples when:

    - there are fewer than ``order + 1`` samples,
    - RANSAC keeps fewer than ``order + 1`` inliers, or
    - anything inside the robust step raises.

    Parameters
    ----------
    x, y : array_like
        Sample coordinates. Both are converted to 1-D float arrays, so Python
        lists and ``(n, 1)`` column arrays are accepted as well as 1-D arrays.
    order : int
        Polynomial degree.

    Returns
    -------
    np.ndarray
        Polynomial coefficients, highest power first (``np.polyfit``
        convention, directly usable with ``np.polyval``).
    """
    # Convert up front: the RANSAC path needs ``reshape`` and boolean-mask
    # indexing. Without this, list input raised inside the try-block below and
    # silently degraded to a non-robust fit.
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()

    if len(x) < order + 1:
        # Not enough points for a robust fit; defer to plain polyfit.
        return np.polyfit(x, y, order)

    try:
        model = make_pipeline(
            PolynomialFeatures(order),
            RANSACRegressor(LinearRegression(), random_state=42),
        )
        model.fit(x.reshape(-1, 1), y)
        # The pipeline does not expose coefficients in np.polyfit form, so
        # RANSAC is only used for its inlier classification and the fit is
        # redone on the inliers with np.polyfit.
        inlier_mask = model.named_steps["ransacregressor"].inlier_mask_
        if np.sum(inlier_mask) < order + 1:
            return np.polyfit(x, y, order)
        return np.polyfit(x[inlier_mask], y[inlier_mask], order)
    except Exception as exc:  # deliberate best-effort: any failure -> plain fit
        # getLogger(__name__) is the same logger object as this module's
        # ``logger``; the fallback is reported instead of being swallowed.
        logging.getLogger(__name__).debug(
            "robust_polyfit: RANSAC step failed (%s); using ordinary polyfit", exc
        )
        return np.polyfit(x, y, order)


def windsor_istats(
    ivec: np.ndarray, n: int, cutoff_percent: float
) -> tuple[float, float, bool]:
    """
    Trimmed ("Windsor") mean and standard deviation of an integer vector.

    The first ``n`` entries of ``ivec`` are sorted. ``(100 - cutoff_percent) / 2``
    percent of them are discarded from each end, and the mean and population
    standard deviation of the remaining central values are returned. For
    example, ``cutoff_percent=75`` keeps the middle 75% by trimming 12.5% from
    each tail. If trimming would leave nothing (only possible for
    ``cutoff_percent <= 0``), all ``n`` values are used instead.

    Parameters
    ----------
    ivec : np.ndarray
        Values to analyse (e.g. the number of arc lines found per fibre).
    n : int
        Number of leading entries of ``ivec`` to use. It is clamped to
        ``len(ivec)`` so that trimming stays symmetric.
    cutoff_percent : float
        Percentage of the sorted data assumed to be good and kept. Values above
        100 keep everything.

    Returns
    -------
    mean : float
        Mean of the retained values.
    sd : float
        Population standard deviation (``np.std``, ddof=0) of the retained values.
    check_flag : bool
        False when there is no data, or when the retained values have a
        positive mean and ``sd / mean > 0.5`` (a spread of more than 50% of the
        mean is treated as suspicious). True otherwise, which includes the case
        of a non-positive mean.
    """
    values = np.asarray(ivec).ravel()
    # Clamp so that ``end_idx`` below never points past the real data.
    n = min(int(n), values.size)
    if n < 1:
        return 0.0, 0.0, False

    sorted_vec = np.sort(values[:n])

    # Samples to drop from EACH end (e.g. 75% -> 12.5% per tail). Clamp at 0
    # so cutoff_percent > 100 cannot produce a negative start index.
    k = max(0, int(n * (100.0 - cutoff_percent) / 200.0))
    start_idx, end_idx = k, n - k
    subset = sorted_vec[start_idx:end_idx] if start_idx < end_idx else sorted_vec

    win_mn = float(np.mean(subset))
    win_sd = float(np.std(subset))

    # Sanity check: flag a relative spread above 50% of the mean.
    chk_flag = not (win_mn > 0 and (win_sd / win_mn) > 0.5)
    return win_mn, win_sd, chk_flag


def landmark_register(
    spectra: np.ndarray,
    npix: int,
    nfib: int,
    maskv: np.ndarray,
    ref_fib: int,
    scale: float,
    ztol: float,
    diagnostic: bool = False,
    *,
    max_disp: float = 20.0,
    min_track_percent: float = 75.0,
) -> tuple[np.ndarray, int]:
    """
    Register arc-line landmarks across fibres.

    For every unmasked fibre the spectrum is convolved with a wavelet of width
    ``scale`` and its peaks are located. The per-fibre peak lists are then
    linked from fibre to fibre with multi-target tracking. Only tracks present
    in more than ``min_track_percent`` percent of the unmasked fibres are kept
    as landmarks.

    Parameters
    ----------
    spectra : np.ndarray
        Extracted spectra indexed as ``spectra[pixel, fibre]`` (npix, nfib).
    npix : int
        Number of pixels along each spectrum.
    nfib : int
        Number of fibres.
    maskv : np.ndarray
        Per-fibre mask; fibres with a truthy entry are skipped.
    ref_fib : int
        Reference fibre index. Not used by this function; kept for interface
        compatibility.
    scale : float
        Wavelet scale passed to ``wavelet_convolution``.
    ztol : float
        Peak threshold for ``find_resonant_peaks2``. If ``<= 0``, 10% of the
        maximum of each fibre's wavelet response is used instead.
    diagnostic : bool
        Currently unused.
    max_disp : float, keyword-only
        Maximum displacement passed to ``multi_target_tracking``. The default
        of 20.0 is the value used by the Fortran original.
    min_track_percent : float, keyword-only
        A track is kept only if it is present in strictly more than this
        percentage of the unmasked fibres (default 75.0).

    Returns
    -------
    lmr : np.ndarray
        Landmark register of shape ``(nfib, nlm)``. ``lmr[fib, j]`` is the
        position of landmark ``j`` in fibre ``fib``, or 0 where the landmark
        was not tracked or the fibre is masked.
    nlm : int
        Number of landmarks (columns of ``lmr``).

    Notes
    -----
    0 is used as the "no peak" sentinel throughout, so a peak located at
    exactly position 0.0 is indistinguishable from a missing one.
    """
    logger.info("Landmark Registration...")
    t = np.arange(npix, dtype=float)

    # 1. Find candidate landmarks (wavelet peaks) in every unmasked fibre.
    #    Each unmasked fibre becomes one tracking "sequence".
    good_fibers = []  # fibre index of each sequence
    peaks_per_fiber = []  # peak positions found in that fibre
    logger.info("-> Identifying landmarks within each extracted fibre")
    for fib in range(nfib):
        if maskv[fib]:
            continue
        # Zero NaNs so they cannot poison the convolution; no further
        # bad-pixel rejection is performed here.
        signal = np.nan_to_num(spectra[:, fib])
        cwt = wavelet_convolution(signal, t, scale)
        fib_ztol = 0.1 * np.max(cwt) if ztol <= 0 else ztol
        pks = find_resonant_peaks2(cwt, t, fib_ztol)
        peaks_per_fiber.append(pks)
        good_fibers.append(fib)

    nseq = len(good_fibers)
    if nseq == 0:
        logger.warning("No good fibers found for landmark registration.")
        return np.zeros((nfib, 0)), 0

    # Sanity check on the spread of per-fibre line counts (warning only).
    counts = np.array([len(p) for p in peaks_per_fiber])
    _, _, chk_flag = windsor_istats(counts, nseq, 75.0)
    if not chk_flag:
        logger.warning("Warning: quality of fibre arclines may be compromised (high variance in line counts)")

    max_peaks_found = int(np.max(counts))
    if max_peaks_found == 0:
        # Nothing to track: return the same empty register as the
        # "no good fibres" case instead of handing an empty grid to tracking.
        logger.warning("No arc landmarks found in any fibre; cannot register landmarks.")
        return np.zeros((nfib, 0)), 0

    # 2. Pack the peaks into a dense grid for tracking: one row per sequence,
    #    sorted positions, unused slots left at 0 ("no peak").
    pk_grid = np.zeros((nseq, max_peaks_found))
    for row, pks in enumerate(peaks_per_fiber):
        if len(pks) > 0:
            pk_grid[row, :len(pks)] = np.sort(pks)

    logger.info("Tracking arc landmarks from fibre to fibre")
    ntracks, tracka = multi_target_tracking(pk_grid, nseq, max_peaks_found, max_disp)
    logger.info("Found %s Tracks", ntracks)

    # 3. Keep tracks that are present (non-zero) in more than
    #    min_track_percent of the sequences. tracka[i, s] is the position of
    #    track i in sequence s.
    valid_track_indices = [
        i
        for i in range(ntracks)
        if 100.0 * np.count_nonzero(tracka[i, :]) / nseq > min_track_percent
    ]
    nuse = len(valid_track_indices)
    if nuse < 3:
        logger.warning("Warning! Unable to trace enough strong arcs fully down the image")

    # 4. Scatter the selected tracks from sequence index back to fibre index.
    lmr = np.zeros((nfib, nuse))
    for use_idx, track_idx in enumerate(valid_track_indices):
        for seq_idx, pos in enumerate(tracka[track_idx, :]):
            if pos > 0:
                lmr[good_fibers[seq_idx], use_idx] = pos
    return lmr, nuse


def synchronise_signals(
    spectra: np.ndarray,
    npix: int,
    nfib: int,
    maskv: np.ndarray,
    ref_fib: int,
    lmr: np.ndarray,
    nlm: int,
) -> np.ndarray:
    """
    Rebin every fibre's spectrum onto the reference fibre's pixel frame.

    For each unmasked fibre, landmarks present (> 0) both in that fibre and in
    ``ref_fib`` are paired. A quadratic mapping ``ref_pixel = f(fibre_pixel)``
    is fitted with ``robust_polyfit`` on coordinates normalised by ``npix``.
    The spectrum is then linearly interpolated onto the reference pixel grid
    ``0 .. npix - 1``.

    Fibres are visited outwards from ``ref_fib`` in two legs: ``ref_fib`` down
    to 0, then ``ref_fib + 1`` up to ``nfib - 1``. A fibre with fewer than 3
    shared landmarks reuses the last successful mapping of its leg. If no
    mapping has been fitted yet on that leg, it uses the identity mapping.

    Parameters
    ----------
    spectra : np.ndarray
        Spectra indexed as ``spectra[pixel, fibre]``.
    npix : int
        Number of pixels per spectrum.
    nfib : int
        Number of fibres.
    maskv : np.ndarray
        Per-fibre mask; truthy fibres are skipped and left as zeros.
    ref_fib : int
        Reference fibre index.
    lmr : np.ndarray
        Landmark register ``lmr[fibre, landmark]``; 0 marks a missing landmark.
    nlm : int
        Number of landmark columns of ``lmr`` to use.

    Returns
    -------
    np.ndarray
        Rebinned spectra with the same shape as ``spectra`` and a floating
        dtype. The following are set to 0:

        - output pixels outside the mapped range,
        - masked fibres,
        - fibres with fewer than two finite samples.
    """
    # Floating output even for integer input, so interpolated values are not
    # truncated on assignment; floating inputs keep their own dtype.
    rebin_spectra = np.zeros_like(spectra, dtype=np.result_type(spectra.dtype, np.float32))
    axis1 = np.arange(npix, dtype=float)  # reference pixel grid
    identity = np.array([0.0, 1.0, 0.0])  # quadratic coefficients of y = x

    ref_lm = np.asarray(lmr[ref_fib, :nlm])
    legs = (range(ref_fib, -1, -1), range(ref_fib + 1, nfib))

    for leg in legs:
        coeffs = identity  # carried forward as the leg's last good mapping
        for fib in leg:
            if maskv[fib]:
                continue

            fib_lm = np.asarray(lmr[fib, :nlm])
            shared = (fib_lm > 0) & (ref_lm > 0)
            n_shared = int(np.count_nonzero(shared))
            if n_shared >= 3:
                # Fit mapping ref_pos = f(fib_pos) in normalised coordinates.
                coeffs = robust_polyfit(fib_lm[shared] / npix, ref_lm[shared] / npix, 2)
            else:
                logger.warning(
                    "Synchronise Signals: Fibre %s has insufficient landmarks (%s). "
                    "Using neighbor coefficients.",
                    fib,
                    n_shared,
                )

            column = spectra[:, fib]
            finite = np.isfinite(column)
            # Linear interp1d needs at least two samples; with fewer it raises.
            if np.count_nonzero(finite) < 2:
                logger.warning(
                    "Synchronise Signals: Fibre %s has fewer than two finite samples; left as zeros.",
                    fib,
                )
                continue

            # axis2[p] is where pixel p of this fibre lands in the reference frame.
            axis2 = np.polyval(coeffs, axis1 / npix) * npix
            f_interp = interp1d(
                axis2[finite],
                column[finite],
                kind="linear",
                bounds_error=False,
                fill_value=0.0,
            )
            rebin_spectra[:, fib] = f_interp(axis1)

    return rebin_spectra


def synchronise_calibration_last(
    cal_axis: np.ndarray,
    npix: int,
    nfib: int,
    maskv: np.ndarray,
    ref_fib: int,
    lmr: np.ndarray,
    nlm: int,
) -> np.ndarray:
    """
    Synchronise the reference fibre's calibration to every other fibre.

    For each unmasked fibre, landmarks present (> 0) both in that fibre and in
    ``ref_fib`` are paired. A cubic mapping ``ref_pixel = f(fibre_pixel)`` is
    fitted with ``robust_polyfit`` on coordinates normalised by ``npix``. The
    fibre's pixel edges ``0 .. npix`` are mapped into reference pixels, and
    the reference calibration is linearly interpolated (and extrapolated) at
    those positions.

    Fibres are visited outwards from ``ref_fib`` in two legs: ``ref_fib`` down
    to 0, then ``ref_fib + 1`` up to ``nfib - 1``. A fibre with fewer than 3
    shared landmarks reuses the last successful mapping of its leg. If no
    mapping has been fitted yet on that leg, it uses the identity mapping.

    Parameters
    ----------
    cal_axis : np.ndarray
        Calibration of the reference fibre (e.g. wavelengths) sampled at the
        pixel edges ``0 .. npix``, so its length must be ``npix + 1``.
    npix : int
        Number of pixels per spectrum.
    nfib : int
        Number of fibres.
    maskv : np.ndarray
        Per-fibre mask; truthy fibres are skipped and their rows left as zeros.
    ref_fib : int
        Reference fibre index.
    lmr : np.ndarray
        Landmark register ``lmr[fibre, landmark]``; 0 marks a missing landmark.
    nlm : int
        Number of landmark columns of ``lmr`` to use.

    Returns
    -------
    np.ndarray
        Array of shape ``(nfib, npix + 1)`` with each fibre's calibration at
        its pixel edges.

    Raises
    ------
    ValueError
        If ``cal_axis`` does not have ``npix + 1`` samples along its last axis.
    """
    cal_axis = np.asarray(cal_axis)
    if cal_axis.shape[-1:] != (npix + 1,):
        raise ValueError(
            f"cal_axis must have npix + 1 = {npix + 1} samples (pixel edges), "
            f"got shape {cal_axis.shape}"
        )

    synchcal_axes = np.zeros((nfib, npix + 1))
    axis1 = np.arange(npix + 1, dtype=float)  # pixel edges 0..npix
    axis1_norm = axis1 / npix
    identity = np.array([0.0, 0.0, 1.0, 0.0])  # cubic coefficients of y = x

    # The reference pixel -> calibration relation is the same for every
    # fibre, so build the interpolator once instead of once per fibre.
    ref_pix_to_cal = interp1d(
        axis1, cal_axis, kind="linear", bounds_error=False, fill_value="extrapolate"
    )

    ref_lm = np.asarray(lmr[ref_fib, :nlm])
    legs = (range(ref_fib, -1, -1), range(ref_fib + 1, nfib))

    for leg in legs:
        coeffs = identity  # carried forward as the leg's last good mapping
        for fib in leg:
            if maskv[fib]:
                continue

            fib_lm = np.asarray(lmr[fib, :nlm])
            shared = (fib_lm > 0) & (ref_lm > 0)
            n_shared = int(np.count_nonzero(shared))
            if n_shared >= 3:
                # Map this fibre's pixels -> reference fibre pixels (cubic).
                coeffs = robust_polyfit(fib_lm[shared] / npix, ref_lm[shared] / npix, 3)
            else:
                logger.warning(
                    "Synchronise Calibration: Fibre %s has insufficient landmarks (%s). "
                    "Using neighbor coefficients.",
                    fib,
                    n_shared,
                )

            # Positions of this fibre's pixel edges in reference-fibre pixels.
            axis2 = np.polyval(coeffs, axis1_norm) * npix
            synchcal_axes[fib, :] = ref_pix_to_cal(axis2)

    return synchcal_axes