"""
Quality-control plotting utilities for kspecdr flux calibration.
All functions return ``matplotlib.figure.Figure`` objects and are designed
for interactive use in Jupyter notebooks. They are **not** called by the
pipeline itself.
Typical usage::
from kspecdr.fluxcal.qc import plot_calibration_summary
fig = plot_calibration_summary(result)
"""
from __future__ import annotations
import logging
from typing import Dict, List, Optional, Sequence
import numpy as np
from .containers import (
CalibrationVector,
FluxCalibrationResult,
Photometry,
Spectrum1D,
StellarTemplate,
)
logger = logging.getLogger(__name__)
__all__ = [
"plot_calibration_summary",
"plot_per_star_vectors",
"plot_calibration_residuals",
"plot_photometric_residuals",
"plot_template_match",
"plot_calibrated_spectrum",
"summarize_calibration",
]
def _import_plt():
    """Import ``matplotlib.pyplot`` on demand and return the module.

    The import happens at call time rather than at module import time, so
    importing this module does not by itself require matplotlib.
    """
    from matplotlib import pyplot

    return pyplot
# ---------------------------------------------------------------------------
# Summary overview
# ---------------------------------------------------------------------------
[docs]
def plot_calibration_summary(
    result: FluxCalibrationResult,
    title: str = "Flux Calibration Summary",
    figsize: tuple = (14, 10),
):
    """Draw a 2x2 overview figure of a flux calibration result.

    Layout:

    * top-left: the combined calibration vector Cal(λ) on its good pixels,
      with masked regions shaded;
    * top-right: every per-star calibration vector, overlaid (a legend is
      drawn only when ``result.n_stars <= 8``);
    * bottom-left: per-star fractional residuals in percent, restricted to
      pixels that are good in both the star and the combined vector;
    * bottom-right: histogram of all those residuals pooled together.

    Parameters
    ----------
    result : FluxCalibrationResult
        Supplies ``combined_vector``, ``per_star_vectors``,
        ``per_star_residuals``, ``n_stars`` and ``summary``.
    title : str
        Figure super-title.
    figsize : tuple
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt = _import_plt()
    fig, grid = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle(title, fontsize=14, y=0.98)
    ax_comb, ax_stars = grid[0, 0], grid[0, 1]
    ax_resid, ax_hist = grid[1, 0], grid[1, 1]

    combined = result.combined_vector
    wave = combined.wavelength
    comb_ok = combined.mask
    # One colour per star; the same colour is reused in panels 2 and 3.
    palette = plt.cm.tab10(np.linspace(0, 1, max(result.n_stars, 1)))
    rms = result.summary.get("rms_scatter", 0)

    # --- Panel 1: combined vector -------------------------------------
    ax_comb.plot(wave[comb_ok], combined.cal_factor[comb_ok], "k-", lw=0.8)
    _shade_masked(ax_comb, wave, comb_ok)
    ax_comb.set_xlabel("Wavelength (Å)")
    ax_comb.set_ylabel("Cal(λ)")
    ax_comb.set_title("Combined calibration vector")
    ax_comb.ticklabel_format(axis="y", style="sci", scilimits=(-2, 2))

    # --- Panel 2: per-star vectors ------------------------------------
    # Each star is drawn against the combined wavelength grid using its
    # own mask.
    for idx, star in enumerate(result.per_star_vectors):
        ax_stars.plot(
            wave[star.mask],
            star.cal_factor[star.mask],
            lw=0.6,
            alpha=0.7,
            color=palette[idx],
            label=star.meta.get("star_name", f"star {idx}"),
        )
    ax_stars.set_xlabel("Wavelength (Å)")
    ax_stars.set_ylabel("Cal(λ)")
    ax_stars.set_title("Per-star vectors")
    ax_stars.ticklabel_format(axis="y", style="sci", scilimits=(-2, 2))
    if result.n_stars <= 8:
        ax_stars.legend(fontsize=7, loc="best")

    # --- Panel 3: fractional residuals (also pooled for panel 4) ------
    pooled_pct = []
    for idx, resid in enumerate(result.per_star_residuals):
        star = result.per_star_vectors[idx]
        keep = star.mask & comb_ok
        pct = resid[keep] * 100
        ax_resid.plot(
            wave[keep],
            pct,
            lw=0.5,
            alpha=0.7,
            color=palette[idx],
            label=star.meta.get("star_name", f"star {idx}"),
        )
        pooled_pct.extend(pct)
    ax_resid.axhline(0, color="k", lw=0.5, ls="--")
    ax_resid.set_xlabel("Wavelength (Å)")
    ax_resid.set_ylabel("Residual (%)")
    ax_resid.set_title("Fractional residuals (star − combined)")
    # Symmetric y-range: ten times the RMS (in percent), but at least ±5 %.
    half_range = max(10 * rms * 100, 5)
    ax_resid.set_ylim(-half_range, half_range)

    # --- Panel 4: residual histogram ----------------------------------
    if pooled_pct:
        ax_hist.hist(pooled_pct, bins=50, color="steelblue", edgecolor="k", lw=0.3)
    ax_hist.axvline(0, color="k", lw=0.5, ls="--")
    ax_hist.set_xlabel("Residual (%)")
    ax_hist.set_ylabel("Count")
    ax_hist.set_title(f"Residual distribution (RMS = {rms*100:.2f}%)")

    # Leave headroom at the top for the super-title.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    return fig
# ---------------------------------------------------------------------------
# Individual per-star vector plot
# ---------------------------------------------------------------------------
[docs]
def plot_per_star_vectors(
    result: FluxCalibrationResult,
    log_scale: bool = False,
    figsize: tuple = (12, 5),
):
    """Plot each per-star calibration vector in its own side-by-side panel.

    Every panel shows the star's calibration factor on its good pixels, a
    shaded band of ``cal_factor ± cal_error``, and light red shading over
    masked regions. All panels share the y axis. When there are no
    per-star vectors, a single panel carrying a short message is returned
    instead of raising.

    Parameters
    ----------
    result : FluxCalibrationResult
        Supplies ``per_star_vectors`` and ``n_stars``.
    log_scale : bool
        If true, use a logarithmic y axis in every panel.
    figsize : tuple
        Size of the whole figure (the first two entries are used).

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt = _import_plt()
    vectors = list(result.per_star_vectors)
    # ``plt.subplots`` rejects zero columns, and the loop below indexes one
    # axis per vector, so the panel count must cover both ``n_stars`` and
    # the actual number of vectors, and never drop below one.
    n_panels = max(result.n_stars, len(vectors), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(figsize[0], figsize[1]),
                             sharey=True, squeeze=False)

    if not vectors:
        # Same empty-input convention as plot_photometric_residuals.
        ax = axes[0, 0]
        ax.text(0.5, 0.5, "No per-star vectors available",
                transform=ax.transAxes, ha="center")
        return fig

    for i, v in enumerate(vectors):
        ax = axes[0, i]
        m = v.meta
        wave = v.wavelength
        good = v.mask

        ax.plot(wave[good], v.cal_factor[good], "k-", lw=0.6)
        err = v.cal_error
        ax.fill_between(wave[good],
                        v.cal_factor[good] - err[good],
                        v.cal_factor[good] + err[good],
                        alpha=0.2, color="steelblue")
        _shade_masked(ax, wave, good)

        ax.set_xlabel("Wavelength (Å)")
        if i == 0:
            # The y axis is shared, so only the left-most panel is labelled.
            ax.set_ylabel("Cal(λ)")
        name = m.get("star_name", f"star {i}")
        teff = m.get("teff", 0)
        ax.set_title(f"{name}\nTeff={teff:.0f} K", fontsize=9)
        if log_scale:
            ax.set_yscale("log")

    fig.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# Calibration residuals
# ---------------------------------------------------------------------------
[docs]
def plot_calibration_residuals(
    result: FluxCalibrationResult,
    figsize: tuple = (12, 4),
):
    """Plot per-star fractional residuals against the combined vector.

    Residuals are drawn in percent on the combined vector's wavelength
    grid, using only pixels that are good in both the star's mask and the
    combined mask. The RMS scatter from ``result.summary`` is shown in the
    title.

    Parameters
    ----------
    result : FluxCalibrationResult
        Supplies ``combined_vector``, ``per_star_vectors``,
        ``per_star_residuals``, ``n_stars`` and ``summary``.
    figsize : tuple
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt = _import_plt()
    fig, ax = plt.subplots(figsize=figsize)

    combined = result.combined_vector
    grid = combined.wavelength
    palette = plt.cm.tab10(np.linspace(0, 1, max(result.n_stars, 1)))

    for idx, frac in enumerate(result.per_star_residuals):
        star = result.per_star_vectors[idx]
        usable = star.mask & combined.mask
        ax.plot(
            grid[usable],
            frac[usable] * 100,
            lw=0.6,
            alpha=0.7,
            color=palette[idx],
            label=star.meta.get("star_name", f"star {idx}"),
        )

    # Zero line for reference.
    ax.axhline(0, color="k", lw=0.5, ls="--")
    ax.set_xlabel("Wavelength (Å)")
    ax.set_ylabel("Residual (%)")
    rms = result.summary.get("rms_scatter", 0)
    ax.set_title(f"Per-star residuals (RMS = {rms*100:.2f}%)")
    ax.legend(fontsize=8)
    fig.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# Photometric residuals
# ---------------------------------------------------------------------------
[docs]
def plot_photometric_residuals(
    cal_vectors: List[CalibrationVector],
    figsize: tuple = (8, 5),
):
    """Grouped bar chart of per-band photometric residuals for each star.

    Bands are the sorted union of the keys found in every vector's
    ``meta["band_residuals"]``. A star with no entry for a band is drawn
    with a zero-height bar. If no vector carries any band residuals, the
    figure contains only a short message.

    Parameters
    ----------
    cal_vectors : list of CalibrationVector
        Each is expected to carry ``meta["band_residuals"]`` (from
        :func:`~.calibration.scale_template_to_photometry`).
    figsize : tuple
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt = _import_plt()
    fig, ax = plt.subplots(figsize=figsize)
    star_count = len(cal_vectors)

    bands = sorted(
        {band for vec in cal_vectors for band in vec.meta.get("band_residuals", {})}
    )
    if not bands:
        ax.text(0.5, 0.5, "No band residuals available",
                transform=ax.transAxes, ha="center")
        return fig

    centres = np.arange(len(bands))
    # Each band group is 0.8 wide in total, split evenly between the stars.
    width = 0.8 / max(star_count, 1)
    palette = plt.cm.tab10(np.linspace(0, 1, max(star_count, 1)))

    for idx, vec in enumerate(cal_vectors):
        per_band = vec.meta.get("band_residuals", {})
        heights = [per_band.get(band, 0.0) for band in bands]
        ax.bar(
            centres + idx * width - 0.4 + width / 2,
            heights,
            width,
            color=palette[idx],
            edgecolor="k",
            lw=0.3,
            label=vec.meta.get("star_name", f"star {idx}"),
            alpha=0.8,
        )

    ax.axhline(0, color="k", lw=0.5, ls="--")
    ax.set_xticks(centres)
    ax.set_xticklabels(bands, rotation=45, ha="right", fontsize=8)
    ax.set_ylabel("Synth − Obs (mag)")
    ax.set_title("Photometric residuals after scaling")
    ax.legend(fontsize=8)
    fig.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# Template match inspection
# ---------------------------------------------------------------------------
[docs]
def plot_template_match(
    observed: Spectrum1D,
    template: Spectrum1D,
    rv_kms: float = 0.0,
    title: str = "",
    figsize: tuple = (12, 6),
):
    """Overlay an observed spectrum and its matched template.

    The top panel shows the raw fluxes. A second panel showing both spectra
    divided by the template continuum is added only when
    ``template.meta["continuum"]`` is present and not ``None``. Both
    spectra are drawn against the observed wavelength grid; the observed
    one is restricted to its good pixels.

    Parameters
    ----------
    observed : Spectrum1D
        Observed standard-star spectrum.
    template : Spectrum1D
        Prepared (convolved + resampled) template spectrum.
    rv_kms : float
        Measured RV in km/s (shown in the legend title).
    title : str
        Optional title for the top panel.
    figsize : tuple
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt = _import_plt()

    continuum = template.meta.get("continuum")
    show_norm = continuum is not None

    fig, axes_2d = plt.subplots(2 if show_norm else 1, 1, figsize=figsize,
                                sharex=True, squeeze=False)
    # squeeze=False always yields a 2-D array; take the single column.
    panels = axes_2d[:, 0]

    grid = observed.wavelength
    ok = observed.mask

    # Top panel: raw flux.
    top = panels[0]
    top.plot(grid[ok], observed.flux[ok], "k-", lw=0.5, label="Observed", alpha=0.8)
    top.plot(grid, template.flux, "r-", lw=0.5, label="Template", alpha=0.7)
    top.set_ylabel("Flux")
    meta = template.meta
    teff = meta.get("teff", "?")
    logg = meta.get("logg", "?")
    feh = meta.get("feh", "?")
    top.legend(fontsize=8, title=f"Teff={teff} logg={logg} [M/H]={feh} RV={rv_kms:.1f} km/s")
    if title:
        top.set_title(title)

    # Bottom panel: both spectra divided by the template continuum.
    if show_norm:
        bottom = panels[1]
        # Non-positive continuum values are replaced by 1 so the division
        # is always defined.
        divisor = np.where(continuum > 0, continuum, 1.0)
        observed_norm = observed.flux / divisor
        template_norm = template.flux / divisor
        bottom.plot(grid[ok], observed_norm[ok], "k-", lw=0.5, alpha=0.8, label="Obs / cont")
        bottom.plot(grid, template_norm, "r-", lw=0.5, alpha=0.7, label="Tmpl / cont")
        bottom.axhline(1.0, color="gray", lw=0.3, ls="--")
        bottom.set_ylabel("Normalised flux")
        bottom.set_ylim(0.3, 1.3)
        bottom.legend(fontsize=8)

    panels[-1].set_xlabel("Wavelength (Å)")
    fig.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# Calibrated spectrum
# ---------------------------------------------------------------------------
[docs]
def plot_calibrated_spectrum(
    wavelength: np.ndarray,
    flux: np.ndarray,
    variance: Optional[np.ndarray] = None,
    fiber_id: int = -1,
    title: str = "",
    figsize: tuple = (12, 4),
):
    """Quick plot of a flux-calibrated spectrum with optional error band.

    Only pixels whose flux is finite and strictly positive are drawn.

    Parameters
    ----------
    wavelength : ndarray
        In Angstrom.
    flux : ndarray
        In erg/s/cm²/Å.
    variance : ndarray, optional
        If provided, a ±1σ band is shown, with σ the square root of the
        variance (non-positive variances are treated as zero).
    fiber_id : int
        Used for the default title ``"Fiber <id>"`` when non-negative.
    title : str
        Explicit plot title. When non-empty it always takes precedence
        over the fiber-based and generic default titles.
    figsize : tuple
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt = _import_plt()
    fig, ax = plt.subplots(figsize=figsize)

    good = np.isfinite(flux) & (flux > 0)
    ax.plot(wavelength[good], flux[good], "k-", lw=0.5)

    if variance is not None:
        # Clamp non-positive variances to zero so sqrt never sees them.
        err = np.sqrt(np.where(variance > 0, variance, 0.0))
        ax.fill_between(
            wavelength[good],
            flux[good] - err[good],
            flux[good] + err[good],
            alpha=0.2, color="steelblue",
        )

    ax.set_xlabel("Wavelength (Å)")
    ax.set_ylabel(r"$f_\lambda$ (erg s$^{-1}$ cm$^{-2}$ Å$^{-1}$)")
    ax.ticklabel_format(axis="y", style="sci", scilimits=(-2, 2))

    # Title precedence: explicit title, then fiber label, then generic text.
    # The former one-line conditional expression bound as
    # ``(title or fiber_label) if fiber_id >= 0 else generic`` and so
    # dropped an explicit title whenever fiber_id was negative (the default).
    if title:
        heading = title
    elif fiber_id >= 0:
        heading = f"Fiber {fiber_id}"
    else:
        heading = "Calibrated spectrum"
    ax.set_title(heading)

    fig.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# Text summary table
# ---------------------------------------------------------------------------
[docs]
def summarize_calibration(result: FluxCalibrationResult) -> str:
    """Build a plain-text report describing a calibration result.

    The report has a header with the global statistics read from
    ``result.summary`` (star counts, RMS scatter in percent, wavelength
    range) and the number of good pixels of the combined vector, followed
    by one line per star with the parameters stored in that vector's
    ``meta``. Stars that carry ``meta["band_residuals"]`` get an extra
    line listing the residuals sorted by band name.

    Suitable for printing in a notebook cell or logging.

    Parameters
    ----------
    result : FluxCalibrationResult

    Returns
    -------
    str
        The report lines joined with newlines.
    """
    summary = result.summary
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    rms = summary.get("rms_scatter", 0)
    wave_range = summary.get("wavelength_range", (0, 0))

    report = [
        heavy_rule,
        " Flux Calibration Summary",
        heavy_rule,
        f" Stars used : {summary.get('n_stars_used', '?')}",
        f" Stars rejected : {summary.get('n_stars_rejected', 0)}",
        f" RMS scatter : {rms*100:.2f}%",
        f" Wavelength range : {wave_range[0]:.0f} – {wave_range[1]:.0f} Å",
        f" Good pixels : {result.combined_vector.n_good}",
        light_rule,
    ]

    for idx, vec in enumerate(result.per_star_vectors):
        meta = vec.meta
        fields = [
            f" Star {idx}: {meta.get('star_name', '?'):>12s} ",
            f"fiber={meta.get('fiber_id', -1):>3d} ",
            f"Teff={meta.get('teff', 0):5.0f} ",
            f"logg={meta.get('logg', 0):.1f} ",
            f"[M/H]={meta.get('feh', 0):+.2f} ",
            f"RV={meta.get('rv_kms', 0):+6.1f} km/s ",
            f"scale={meta.get('scale_factor', 0):.3e}",
        ]
        report.append("".join(fields))

        band_resid = meta.get("band_residuals", {})
        if band_resid:
            pairs = sorted(band_resid.items())
            joined = " ".join(f"{band}={value:+.3f}" for band, value in pairs)
            report.append(" resid(mag): " + joined)

    report.append(heavy_rule)
    return "\n".join(report)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _shade_masked(ax, wavelength, mask):
    """Shade every contiguous run of masked (``False``) pixels in light red.

    Each span starts at the wavelength of the first masked pixel of a run
    and ends at the wavelength of the first good pixel after it, or at the
    last wavelength when the run reaches the end of the array. Nothing is
    drawn when no pixel is masked.
    """
    bad = ~mask
    if not bad.any():
        return

    last = len(wavelength) - 1
    flags = bad.astype(int)
    # +1 marks a good->bad transition, -1 a bad->good transition.
    step = flags[1:] - flags[:-1]
    run_starts = list(np.flatnonzero(step == 1) + 1)
    run_stops = list(np.flatnonzero(step == -1) + 1)

    # Runs touching either edge of the array have no transition there.
    if bad[0]:
        run_starts.insert(0, 0)
    if bad[-1]:
        run_stops.append(last)

    for lo, hi in zip(run_starts, run_stops):
        ax.axvspan(wavelength[lo], wavelength[min(hi, last)],
                   alpha=0.08, color="red", zorder=0)