Source code for kspecdr.wavecal.wavelets

"""
Wavelet functions for peak detection and signal analysis.
Implements Mexican Hat, Haar, and N2BSpline wavelets and convolution routines.
"""

import numpy as np
import logging
from scipy import signal as scipy_signal

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Constants
# sqrt(2*pi), written out to ten significant figures. mexican_hat_wavelet uses
# it in the amplitude term 1 / (ROOT_2PI * sigma**3).
ROOT_2PI = 2.506628274


def mexican_hat_wavelet(t: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """
    Evaluate the Mexican Hat wavelet at the positions ``t``.

    The value returned is ``A * (1 - z^2) * exp(-0.5 * z^2)`` with
    ``z = t / sigma`` and ``A = 1 / (sqrt(2*pi) * sigma^3)``.

    Parameters
    ----------
    t : np.ndarray
        Positions at which the wavelet is evaluated.
    sigma : float
        Width parameter of the wavelet (default 1.0).

    Returns
    -------
    np.ndarray
        Wavelet values, one per entry of ``t``.
    """
    amplitude = 1.0 / (ROOT_2PI * sigma**3)
    z_squared = (t / sigma) ** 2
    return amplitude * (1 - z_squared) * np.exp(-0.5 * z_squared)
def wavelet_convolution(signal: np.ndarray, t: np.ndarray, scale: float) -> np.ndarray:
    """
    Perform continuous wavelet transform convolution using Mexican Hat wavelet.

    The kernel is the daughter wavelet ``1/sqrt(a) * psi(tau / a)`` sampled on
    the grid of ``t`` and multiplied by the grid step, so that the discrete
    convolution approximates the CWT integral.

    Parameters
    ----------
    signal : np.ndarray
        Input signal.
    t : np.ndarray
        Time/position axis for the signal. Assumed to be regularly spaced;
        only the first two samples are used to determine the step. A step of
        1.0 is used when ``t`` has fewer than two samples.
    scale : float
        Wavelet scale parameter (a).

    Returns
    -------
    np.ndarray
        Convolved signal (wavelet coefficients), same length as ``signal``.
    """
    dt = t[1] - t[0] if len(t) > 1 else 1.0

    # The Mexican hat decays quickly, so +/- 5*scale of support is used.
    # At least one sample is kept on either side of the centre.
    half_width = max(1, int(np.ceil(5.0 * scale / dt)))

    # Kernel support in units of t, symmetric about zero.
    t_kernel = np.arange(-half_width, half_width + 1) * dt

    # Daughter wavelet 1/sqrt(a) * psi(tau / a), where psi is the mother
    # wavelet with sigma=1.0. The shift b is handled by the convolution itself.
    kernel_vals = mexican_hat_wavelet(t_kernel / scale, sigma=1.0) * (
        1.0 / np.sqrt(scale)
    )

    # Multiply by dt for the integral approximation: sum(signal * wavelet * dt).
    kernel = kernel_vals * dt

    # mode="same" returns output of the same length as signal.
    return scipy_signal.convolve(signal, kernel, mode="same")
def wavelet_find_res_peaks_ztol(
    signal: np.ndarray, t: np.ndarray, ztol: float
) -> np.ndarray:
    """
    Locate the maximum of every run of samples that stays at or above ``ztol``.

    Only interior samples (indices 1 .. n-2) are scanned. A run starts at the
    first scanned sample that is ``>= ztol`` and is closed by the first later
    scanned sample that is ``< ztol``. A run that is still open when the scan
    ends produces no peak.

    Parameters
    ----------
    signal : np.ndarray
        Samples to scan.
    t : np.ndarray
        Position axis; accepted for interface compatibility, not used here.
    ztol : float
        Threshold that separates "in a run" from "outside a run".

    Returns
    -------
    np.ndarray
        Integer indices of the run maxima, in increasing order.
    """
    found = []
    run_start = None  # index where the current run began, or None outside a run

    for idx in range(1, len(signal) - 1):
        value = signal[idx]

        if run_start is None:
            if value >= ztol:
                run_start = idx
            continue

        if value < ztol:
            # The run covers run_start .. idx-1; take the first maximum in it.
            segment = signal[run_start:idx]
            found.append(run_start + np.argmax(segment))
            run_start = None

    return np.array(found, dtype=int)
def _interpolate_zero(signal, t, j0, j1):
    """Return the position where the line through samples j0 and j1 is zero, or None if it is flat."""
    denom = signal[j1] - signal[j0]
    if denom == 0:
        return None
    return t[j0] - (t[j1] - t[j0]) / denom * signal[j0]


def wavelet_find_zero_crossings2(
    signal: np.ndarray, t: np.ndarray, peaks: np.ndarray, ztol: float
) -> tuple[np.ndarray, np.ndarray]:
    """
    Find LHS and RHS zero crossings for each peak.

    For every peak index whose sample is strictly positive, the signal is
    scanned outwards in both directions until the first negative sample. The
    crossing position is obtained by linear interpolation between that sample
    and its neighbour on the peak side. A peak contributes to the output only
    when both crossings are found, so the returned arrays can be shorter than
    ``peaks``.

    Parameters
    ----------
    signal : np.ndarray
        Samples containing the peaks.
    t : np.ndarray
        Position axis matching ``signal``. Positions may be negative.
    peaks : np.ndarray
        Integer indices into ``signal``.
    ztol : float
        Accepted for interface compatibility; not used here.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Left-hand and right-hand crossing positions, pairwise aligned.
    """
    lhs_zc = []
    rhs_zc = []
    n = len(signal)

    for p_idx in peaks:
        if signal[p_idx] <= 0.0:
            continue

        # Scan left down to and including index 0. The Fortran reference loop
        # `DO J=PIX(I),1,-1` is 1-based, so it reaches the first sample too.
        zero_lhs = None
        for j in range(p_idx, -1, -1):
            if signal[j] < 0.0:
                if j + 1 < n:
                    zero_lhs = _interpolate_zero(signal, t, j, j + 1)
                break

        # Scan right up to the last sample.
        zero_rhs = None
        for j in range(p_idx, n):
            if signal[j] < 0.0:
                if j >= 1:
                    zero_rhs = _interpolate_zero(signal, t, j - 1, j)
                break

        # None means "no crossing found". A numeric sentinel is not used so
        # that crossings at negative positions are kept.
        if zero_lhs is None or zero_rhs is None:
            continue
        # NaN samples next to the crossing give a NaN position; discard those.
        if np.isnan(zero_lhs) or np.isnan(zero_rhs):
            continue

        lhs_zc.append(zero_lhs)
        rhs_zc.append(zero_rhs)

    return np.array(lhs_zc), np.array(rhs_zc)
def find_resonant_peaks2(signal: np.ndarray, t: np.ndarray, ztol: float) -> np.ndarray:
    """
    Locate peaks as the midpoints of their flanking zero crossings
    (WAVELET_FIND_RES_PEAKS2).

    Candidate peak indices come from ``wavelet_find_res_peaks_ztol``. Each
    candidate is then bracketed by ``wavelet_find_zero_crossings2``, and the
    reported position is the mean of the left and right crossing.

    Parameters
    ----------
    signal : np.ndarray
        Samples to search.
    t : np.ndarray
        Position axis matching ``signal``.
    ztol : float
        Threshold passed through to both helper routines.

    Returns
    -------
    np.ndarray
        Interpolated (float) peak positions; empty when no crossings are found.
    """
    candidates = wavelet_find_res_peaks_ztol(signal, t, ztol)
    left, right = wavelet_find_zero_crossings2(signal, t, candidates, ztol)

    if len(left) == 0:
        return np.array([])

    return 0.5 * (left + right)
def calc_medmad(data: np.ndarray) -> tuple[float, float]:
    """
    Return the median and the median absolute deviation (MAD) of ``data``.

    Non-finite entries (NaN, +/-inf) are ignored. When no finite entry
    remains, ``(0.0, 0.0)`` is returned.

    Parameters
    ----------
    data : np.ndarray
        Values to summarise.

    Returns
    -------
    tuple[float, float]
        ``(median, MAD)`` of the finite values. The MAD is not rescaled.
    """
    finite = data[np.isfinite(data)]
    if finite.size == 0:
        return 0.0, 0.0

    centre = np.median(finite)
    spread = np.median(np.abs(finite - centre))
    return centre, spread
def analyse_arc_signal(arc_sig: np.ndarray) -> tuple[float, float, float, float, float]:
    """
    Analyse arc signal to estimate noise and PSF parameters.

    Parameters
    ----------
    arc_sig : np.ndarray
        Arc signal samples. Positions are taken to be 0, 1, ..., n-1.

    Returns
    -------
    mn_noise : float
        Median of the finite samples.
    sd_noise : float
        Median absolute deviation of the finite samples.
    al_sigma : float
        Median arc-line sigma derived from the zero-crossing widths of the
        scale-1 wavelet transform; 1.0 when no usable width is found.
    ares : float
        Resonance scale, ``sqrt(5) * al_sigma``.
    ztol : float
        Zero tolerance for the resonance-scale transform; 0.0 when the
        maximum of ``arc_sig`` is not positive.

    Raises
    ------
    ValueError
        If ``arc_sig`` is empty.
    """
    n = len(arc_sig)
    if n == 0:
        # Fail early with a clear message; the maxima below are undefined
        # for an empty signal.
        raise ValueError("arc_sig must contain at least one sample")

    # 1. Estimate noise.
    # The port notes say the Fortran reference makes a second pass over
    # |arc_sig - median| that skips samples equal to the median. This port
    # keeps the plain MAD returned by calc_medmad.
    mn_noise, sd_noise = calc_medmad(arc_sig)

    # 2. Find arc line sigma from the CWT at scale 1.0.
    scale = 1.0
    t = np.arange(n, dtype=float)
    cwt = wavelet_convolution(arc_sig, t, scale)

    # Peak indices are needed first; the zero crossings are found around them.
    ztol = 0.01 * np.max(cwt)
    peak_indices = wavelet_find_res_peaks_ztol(cwt, t, ztol)
    lhs, rhs = wavelet_find_zero_crossings2(cwt, t, peak_indices, ztol)
    widths = rhs - lhs

    # (gap/2)^2 = scale^2 + sigma^2  =>  sigma = sqrt((gap/2)^2 - scale^2).
    # Only positive widths that leave a positive sigma^2 are usable.
    sigma_sq = (0.5 * widths) ** 2 - scale**2
    usable = (widths > 0) & (sigma_sq > 0)
    valid_sigmas = np.sqrt(sigma_sq[usable])
    if valid_sigmas.size:
        al_sigma = np.median(valid_sigmas)
    else:
        al_sigma = 1.0  # Fallback

    # 3. Wavelet resonance analysis.
    ares = np.sqrt(5.0) * al_sigma

    # Recalculate the CWT at the resonance scale to get ZTOL: the 3*sd_noise
    # level is rescaled by the ratio of the transform maximum to the signal
    # maximum.
    cwt_res = wavelet_convolution(arc_sig, t, ares)
    max_sig = np.max(arc_sig)
    max_cwt = np.max(cwt_res)
    if max_sig > 0:
        ztol = (3 * sd_noise) * max_cwt / max_sig
    else:
        ztol = 0.0

    return mn_noise, sd_noise, al_sigma, ares, ztol