Solver Accuracy Comparison: LSQR vs FISTA Group-L1

Wide-field slitless spectroscopy (WFSS) with JWST NIRISS disperses the light of every object in the field simultaneously, causing spectra from neighbouring sources to overlap on the detector. In crowded environments — star clusters, the cores of nearby galaxies, or deep extragalactic fields — this source confusion degrades spectral fidelity unless the extraction algorithm explicitly models and separates each contribution. Here we compare two solvers: SpectralSolver, which minimises a Tikhonov-regularised least-squares problem via LSQR, and JAXProximalSolver, which applies group-\(\ell_1\) regularisation via the FISTA proximal gradient algorithm, encouraging per-source sparsity in the recovered spectral coefficients.
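For orientation, the two objectives can be written schematically as below. This is a sketch under assumptions rather than the definition used inside spectrex: the symbols are notation introduced here (\(H\) for the forward operator, \(f\) for the detector image, \(W\) for an inverse-variance weight matrix, \(a_k\) for the coefficient block of source \(k\) out of \(K\)), and \(\alpha\) and \(\lambda\) presumably correspond to the `regularisation` and `lam` arguments used later, although the library's exact weighting and scaling may differ.

```latex
\begin{aligned}
\hat{a}_{\mathrm{Tikhonov}} &= \arg\min_{a}\; \lVert W^{1/2}(H a - f) \rVert_2^2 + \alpha \lVert a \rVert_2^2, \\
\hat{a}_{\mathrm{group}\text{-}\ell_1} &= \arg\min_{a}\; \tfrac{1}{2} \lVert W^{1/2}(H a - f) \rVert_2^2 + \lambda \sum_{k=1}^{K} \lVert a_k \rVert_2 .
\end{aligned}
```

The Tikhonov penalty is a squared norm, so it shrinks every coefficient smoothly without setting any to zero. The group penalty sums unsquared block norms, which is what allows an entire source block to be driven exactly to zero.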

from __future__ import annotations

import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import spectrex
from spectrex import (
    EigenspectraBasis,
    InstrumentConfig,
    JAXOperator,
    JAXProximalSolver,
    NoiseModel,
    SciPySparseOperator,
    SpectralSolver,
)

warnings.filterwarnings('ignore')

NOTEBOOK_DIR = Path.cwd()
REPO = NOTEBOOK_DIR.parent
TESTDATA = REPO / 'testdata'
print(f'spectrex version: {spectrex.__version__}')
spectrex version: 0.2.1.dev3+g6e87286d0.d20260505

Setup: Instrument Configuration and Basis

config = InstrumentConfig.from_files(
    conf_path=TESTDATA / 'Config Files' / 'GR150R.F150W.220725.conf',
    wavelengthrange_path=TESTDATA / 'jwst_niriss_wavelengthrange_0002.asdf',
    sensitivity_dir=TESTDATA / 'SenseConfig' / 'wfss-grism-configuration',
    filter_name='F150W',
    n_wavelengths=150,
)
basis = EigenspectraBasis.from_csv(
    TESTDATA / 'eigenspectra_kurucz.csv',
    config.wavelengths,
)

IMAGE_SHAPE = (50, 20)
N_ROWS, N_COLS = IMAGE_SHAPE
N_PIX = N_ROWS * N_COLS
M = basis.n_components
NOISE_MODEL = NoiseModel(read_noise=5.0)
RNG = np.random.default_rng(2026)

print(f'Image shape: {IMAGE_SHAPE}, n_pix={N_PIX}')
print(f'Basis components: M={M}')
print(f'Wavelengths: {len(config.wavelengths)} points, '
      f'{config.wavelengths[0]:.0f}\u2013{config.wavelengths[-1]:.0f} \u00c5')
Image shape: (50, 20), n_pix=1000
Basis components: M=10
Wavelengths: 150 points, 12900–17100 Å

Section 1: Fixed Crowded Scene

We place five synthetic stellar sources at fixed positions within a 50×20 pixel NIRISS WFSS sub-image, chosen so that their dispersed spectra partially overlap along the dispersion direction. Both solvers receive the same noisy detector image and the same knowledge of source positions; the only difference is the regularisation strategy used to separate their contributions.

Build Operators

# Fixed source positions (row, col) in the 50x20 image
SOURCE_POSITIONS = np.array([
    [ 8.0,  4.0],
    [15.0, 10.0],
    [25.0,  6.0],
    [35.0, 14.0],
    [42.0,  8.0],
], dtype=np.float64)  # shape (5, 2)
K = len(SOURCE_POSITIONS)

# The sparse operator spans every pixel: n_coefficients = N_PIX * M.
print(f'Building SciPySparseOperator for {K} sources...')
scipy_op = SciPySparseOperator.build(config, basis, IMAGE_SHAPE)
print(f'  n_coefficients = {scipy_op.n_coefficients}')

# The JAX operator covers only the K source positions: n_coefficients = K * M.
print(f'Building JAXOperator for {K} sources...')
jax_op = JAXOperator.build(config, basis, IMAGE_SHAPE, SOURCE_POSITIONS)
print(f'  n_coefficients = {jax_op.n_coefficients}')
print('Done.')
Building SciPySparseOperator for 5 sources...
  n_coefficients = 10000
Building JAXOperator for 5 sources...
  n_coefficients = 50
Done.

Mock Crowded Scene

# Ground-truth coefficients: random but seeded
# JAXOperator uses compact layout: a_true shape (K*M,)
a_true_jax = RNG.standard_normal(K * M).astype(np.float64)

# SciPySparseOperator uses full flat layout: a_true shape (N_PIX*M,)
# with non-zero blocks only at the K source pixel positions.
# Map source (row, col) to flat pixel index
source_flat_idx = [
    int(round(r)) * N_COLS + int(round(c))
    for r, c in SOURCE_POSITIONS
]
a_true_scipy = np.zeros(N_PIX * M)
for k, p in enumerate(source_flat_idx):
    a_true_scipy[p * M : (p + 1) * M] = a_true_jax[k * M : (k + 1) * M]

# Forward model in both layouts; only f_clean_jax is used below
f_clean_jax = jax_op.apply(a_true_jax).reshape(IMAGE_SHAPE)
f_clean_scipy = scipy_op.apply(a_true_scipy).reshape(IMAGE_SHAPE)

# Add noise (same noise realisation for both)
noise_rng = np.random.default_rng(42)
f_noisy = NOISE_MODEL.sample(f_clean_jax, noise_rng)

print(f'f_noisy: min={f_noisy.min():.2f}, max={f_noisy.max():.2f}')
print(f'Max clean signal: {f_clean_jax.max():.2f}')
f_noisy: min=-18.24, max=15.89
Max clean signal: 0.00
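
A maximum clean signal that prints as 0.00 against a read noise of 5.0 is worth checking before interpreting any solver comparison, since it would mean the noisy image carries essentially no source flux. The following is a minimal sketch of such a check on toy arrays standing in for `f_clean_jax`; `signal_report` is a hypothetical helper, not part of spectrex.

```python
import numpy as np


def signal_report(f_clean, read_noise):
    """Summarise how much of a noiseless model image rises above the read noise."""
    f_clean = np.asarray(f_clean, dtype=float)
    peak = float(np.max(np.abs(f_clean)))  # absolute value: flux may be negative
    return {
        "peak": peak,
        "peak_snr": peak / read_noise,
        "frac_pixels_above_1sigma": float(np.mean(np.abs(f_clean) > read_noise)),
    }


# Toy stand-ins: an empty image and one containing a short bright trace.
empty = np.zeros((50, 20))
traced = empty.copy()
traced[25, 5:15] = 40.0

report_empty = signal_report(empty, read_noise=5.0)
report_traced = signal_report(traced, read_noise=5.0)
print(report_empty)
print(report_traced)
```

If the peak signal-to-noise ratio is far below one, both solvers are effectively fitting noise and their outputs cannot be used to rank them.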

Mock Detector Image

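The image shows the simulated NIRISS GR150R grism exposure, with each source position marked by a red cross. A dispersed source produces a horizontal streak of light spanning roughly 60–80 pixels along the detector column axis, and where streaks from adjacent sources overlap, photons from one object contaminate the spectral trace of its neighbour. This is the fundamental deblending problem that the solvers must address. The sub-image is only 20 columns wide, however, and the cell above reports a maximum clean signal of 0.00, so in this run the displayed image is dominated by read noise (5 counts) rather than by visible traces.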

fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(f_noisy, origin='lower', aspect='auto',
               cmap='viridis', interpolation='nearest')
plt.colorbar(im, ax=ax, label='Counts')
for k, (r, c) in enumerate(SOURCE_POSITIONS):
    ax.plot(c, r, 'r+', markersize=12, markeredgewidth=2)
    ax.annotate(f'S{k+1}', xy=(c, r), xytext=(c+0.5, r+0.5),
                color='white', fontsize=8)
ax.set_xlabel('Column (pixel)')
ax.set_ylabel('Row (pixel)')
ax.set_title('Mock NIRISS WFSS Detector Image — 5 crowded sources')
plt.tight_layout()
plt.show()
[Figure: Mock NIRISS WFSS detector image with five crowded sources; axes are column (pixel) and row (pixel), colour scale in counts.]

Spectral Extraction with Both Solvers

# Support mask for SpectralSolver (non-zero at source pixel blocks)
support_mask = np.zeros(N_PIX * M, dtype=bool)
for p in source_flat_idx:
    support_mask[p * M : (p + 1) * M] = True

# --- LSQR (SpectralSolver) ---
import time
t0 = time.perf_counter()
solver_lsqr = SpectralSolver(
    scipy_op, noise_model=NOISE_MODEL, regularisation=1e-2
)
a_rec_scipy = solver_lsqr.solve(f_noisy, support_mask=support_mask)
t_lsqr = time.perf_counter() - t0

# Extract active blocks
a_rec_lsqr = np.array([
    a_rec_scipy[p * M : (p + 1) * M] for p in source_flat_idx
]).reshape(K * M)  # (K*M,)

# --- FISTA (JAXProximalSolver) ---
t0 = time.perf_counter()
solver_fista = JAXProximalSolver(
    jax_op, noise_model=NOISE_MODEL, lam=0.05, max_iter=200
)
a_rec_fista = solver_fista.solve(f_noisy)  # (K*M,)
t_fista = time.perf_counter() - t0

print(f'LSQR  solve time: {t_lsqr:.2f} s')
print(f'FISTA solve time: {t_fista:.2f} s')
LSQR  solve time: 0.00 s
FISTA solve time: 0.42 s

Per-Source Recovered Spectra

Each panel compares the ground-truth spectrum (black) against the LSQR reconstruction (blue dashed) and the FISTA reconstruction (red). When a source’s trace overlaps with a neighbour’s, leaked flux from the contaminating source is expected to appear in the LSQR reconstruction as a broad additive offset or an artificial spectral feature. FISTA’s group-\(\ell_1\) prior penalises the \(\ell_2\) norm of each source’s coefficient block, so it tends to suppress weakly constrained blocks more aggressively while still fitting the dominant spectral shape.
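
To make the group-\(\ell_1\) mechanism concrete, here is a minimal NumPy sketch of FISTA with a block soft-threshold step on a small dense toy problem. It is not the JAXProximalSolver implementation: the function names, the unweighted data term, and the fixed step size are all assumptions made for illustration.

```python
import numpy as np


def group_soft_threshold(a, n_groups, tau):
    """Prox of tau * sum_k ||a_k||_2: shrink each block norm by tau, zeroing weak blocks."""
    blocks = a.reshape(n_groups, -1)
    norms = np.linalg.norm(blocks, axis=1, keepdims=True)
    scale = np.maximum(1.0 - tau / np.maximum(norms, 1e-300), 0.0)
    return (scale * blocks).ravel()


def fista_group_l1(H, f, n_groups, lam, n_iter=200):
    """Minimise 0.5 * ||H a - f||^2 + lam * sum_k ||a_k||_2 with FISTA."""
    step = 1.0 / np.linalg.norm(H, 2) ** 2  # 1 / Lipschitz constant of the gradient
    a = np.zeros(H.shape[1])
    y, t = a.copy(), 1.0
    for _ in range(n_iter):
        grad = H.T @ (H @ y - f)
        a_next = group_soft_threshold(y - step * grad, n_groups, lam * step)
        t_next = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        y = a_next + ((t - 1.0) / t_next) * (a_next - a)  # momentum extrapolation
        a, t = a_next, t_next
    return a


# Toy problem: 3 groups of 4 coefficients, only the first group is active.
rng = np.random.default_rng(0)
H = rng.standard_normal((60, 12))
a_true = np.zeros(12)
a_true[:4] = [2.0, -1.0, 0.5, 1.5]
f = H @ a_true + 0.1 * rng.standard_normal(60)

a_hat = fista_group_l1(H, f, n_groups=3, lam=8.0)
block_norms = np.linalg.norm(a_hat.reshape(3, 4), axis=1)
print(block_norms)
```

The inactive blocks come out as exact zeros rather than small values, which is the behaviour that distinguishes this prior from Tikhonov damping.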

# Reconstruct spectra from coefficients using the basis
# basis.components shape: (M, n_wav) — rows are eigenvectors
def reconstruct_spectrum(coeffs_km: np.ndarray, k: int) -> np.ndarray:
    """Reconstruct spectrum for source k from flat coefficient vector."""
    c = coeffs_km[k * M : (k + 1) * M]  # (M,)
    return basis.components @ c  # (n_wav,)

wav = config.wavelengths / 1e4  # Convert Angstrom to micron for plot

fig, axes = plt.subplots(1, K, figsize=(4 * K, 3), sharey=False)
for k, ax in enumerate(axes):
    sp_true  = reconstruct_spectrum(a_true_jax,  k)
    sp_lsqr  = reconstruct_spectrum(a_rec_lsqr,  k)
    sp_fista = reconstruct_spectrum(a_rec_fista, k)
    ax.plot(wav, sp_true,  'k-',  lw=2,   label='Ground truth', alpha=0.8)
    ax.plot(wav, sp_lsqr,  'b--', lw=1.5, label='LSQR')
    ax.plot(wav, sp_fista, 'r-',  lw=1.5, label='FISTA')
    ax.set_title(f'Source {k+1}')
    ax.set_xlabel('Wavelength (micron)')
    if k == 0:
        ax.set_ylabel('Flux (arb. units)')
    ax.legend(fontsize=7)
    ax.grid(True, alpha=0.3)
fig.suptitle('Recovered Spectra: Ground Truth vs LSQR vs FISTA', y=1.01)
plt.tight_layout()
plt.show()
[Figure: Recovered spectra for sources 1 to 5, ground truth vs LSQR vs FISTA; flux (arbitrary units) against wavelength (micron).]

Residual Images

A well-behaved residual image should be consistent with the noise model: spatially uncorrelated and scattered about zero, with an RMS close to the read noise. Structured residuals (horizontal or diagonal streaks aligned with dispersion traces) indicate that the model has not fully deblended the overlapping spectra. FISTA’s per-source sparsity is intended to reduce such correlated patterns by preventing the solver from spreading one source’s flux across its neighbours’ coefficient vectors. In this run the two residual images are statistically indistinguishable: both have an RMS of 4.9436, close to the read noise of 5.0 set in NOISE_MODEL.
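
Whether a residual image is "spatially uncorrelated and close to zero" can be checked numerically as well as by eye. The sketch below computes a reduced chi-square and lag-1 autocorrelations along both image axes for toy residuals; `residual_diagnostics` is a hypothetical helper, and a single constant `sigma` is assumed for every pixel.

```python
import numpy as np


def residual_diagnostics(residual, sigma):
    """Reduced chi-square and lag-1 autocorrelation along each image axis."""
    z = np.asarray(residual, dtype=float) / sigma
    chi2_red = float(np.mean(z ** 2))
    zc = z - z.mean()
    denom = float(np.sum(zc ** 2))
    lag_cols = float(np.sum(zc[:, 1:] * zc[:, :-1]) / denom)  # neighbours within a row
    lag_rows = float(np.sum(zc[1:, :] * zc[:-1, :]) / denom)  # neighbours within a column
    return {"chi2_red": chi2_red, "lag1_cols": lag_cols, "lag1_rows": lag_rows}


rng = np.random.default_rng(1)
white = 5.0 * rng.standard_normal((50, 20))                    # pure read noise
streaky = white + 10.0 * np.sin(np.arange(50) / 4.0)[:, None]  # row-aligned structure

diag_white = residual_diagnostics(white, sigma=5.0)
diag_streaky = residual_diagnostics(streaky, sigma=5.0)
print(diag_white)
print(diag_streaky)
```

A reduced chi-square near one with autocorrelations near zero is what pure noise looks like. Streaks left behind by imperfect deblending raise both.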

# Reconstruct model images from recovered coefficients
f_model_lsqr  = scipy_op.apply(a_rec_scipy).reshape(IMAGE_SHAPE)
f_model_fista = jax_op.apply(a_rec_fista).reshape(IMAGE_SHAPE)

residual_lsqr  = f_noisy - f_model_lsqr
residual_fista = f_noisy - f_model_fista

vlim = np.percentile(np.abs(residual_lsqr), 99)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, resid, title in zip(
    axes,
    [residual_lsqr, residual_fista],
    ['Residual: LSQR', 'Residual: FISTA'],
):
    im = ax.imshow(resid, origin='lower', aspect='auto',
                   cmap='RdBu_r', vmin=-vlim, vmax=vlim)
    plt.colorbar(im, ax=ax, label='Counts')
    ax.set_title(title)
    ax.set_xlabel('Column')
    ax.set_ylabel('Row')
plt.suptitle('Residual Images (data - model)', y=1.02)
plt.tight_layout()
plt.show()
print(f'LSQR  residual RMS: {np.std(residual_lsqr):.4f}')
print(f'FISTA residual RMS: {np.std(residual_fista):.4f}')
[Figure: Residual images (data - model) for LSQR and FISTA; colour scale in counts.]
LSQR  residual RMS: 4.9436
FISTA residual RMS: 4.9436

Per-Source RMSE Comparison

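def per_source_rmse(a_rec: np.ndarray, a_true: np.ndarray) -> list[float]:
    """RMSE of each source's M-coefficient block (compact K*M layout)."""
    err = (np.asarray(a_rec) - np.asarray(a_true)).reshape(K, M)
    return [float(v) for v in np.sqrt(np.mean(err**2, axis=1))]

rmse_lsqr  = per_source_rmse(a_rec_lsqr,  a_true_jax)
rmse_fista = per_source_rmse(a_rec_fista, a_true_jax)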

x = np.arange(K)
width = 0.35
fig, ax = plt.subplots(figsize=(7, 4))
ax.bar(x - width/2, rmse_lsqr,  width, label='LSQR',  color='steelblue')
ax.bar(x + width/2, rmse_fista, width, label='FISTA', color='tomato')
ax.set_xticks(x)
ax.set_xticklabels([f'S{k+1}' for k in range(K)])
ax.set_ylabel('RMSE (coefficient units)')
ax.set_title('Per-Source RMSE: LSQR vs FISTA')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
print('LSQR  RMSE per source:', [f'{v:.4f}' for v in rmse_lsqr])
print('FISTA RMSE per source:', [f'{v:.4f}' for v in rmse_fista])
[Figure: Per-source RMSE bar chart, LSQR vs FISTA, for sources S1 to S5.]
LSQR  RMSE per source: ['0.8388', '0.5973', '0.9136', '0.8911', '1.6121']
FISTA RMSE per source: ['0.8388', '0.5973', '0.9136', '0.8911', '1.6121']
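
The two rows above agree to four decimals for every source, which is unusual for solvers with different regularisers. One quick way to see whether the recovered coefficient vectors themselves coincide is to compare them directly. The sketch below uses toy vectors in place of `a_rec_lsqr` and `a_rec_fista`; `compare_solutions` is a hypothetical helper.

```python
import numpy as np


def compare_solutions(a, b, rtol=1e-6):
    """Report how far apart two coefficient vectors are."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    max_abs = float(np.max(np.abs(a - b)))
    scale = max(float(np.max(np.abs(a))), float(np.max(np.abs(b))), 1e-300)
    return {
        "max_abs_diff": max_abs,
        "max_rel_diff": max_abs / scale,
        "identical": bool(max_abs <= rtol * scale),
    }


# Toy stand-ins: two all-zero solutions, and one shifted by a visible offset.
zeros_a = np.zeros(50)
zeros_b = np.zeros(50)
shifted = zeros_a + 0.1

same = compare_solutions(zeros_a, zeros_b)
different = compare_solutions(zeros_a, shifted)
print(same)
print(different)
```

If both solvers return coefficients at or near zero, the RMSE against unit-variance ground truth will sit close to one for each source, regardless of the solver.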

Section 2: RMSE vs Source Density

A single five-source scene cannot tell us when FISTA’s additional complexity pays off. We therefore sweep over source densities from 1 to 20 objects in the same 50×20 pixel field, running ten independent Monte Carlo trials at each density — randomising both source positions and spectral shapes — and averaging the coefficient-space RMSE. The crossover density at which FISTA outperforms LSQR marks the practical threshold above which group-\(\ell_1\) regularisation is warranted for a given observation.
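
The crossover density described here can be read off the per-density mean RMSE once the sweep has run. The following is a minimal sketch with made-up numbers; `crossover_density` and its `min_gain` threshold are hypothetical, and a real analysis would also account for trial-to-trial scatter.

```python
import numpy as np


def crossover_density(densities, rmse_a, rmse_b, min_gain=0.0):
    """First density at which method b beats method a by more than min_gain, else None."""
    densities = np.asarray(densities)
    gain = np.asarray(rmse_a, dtype=float) - np.asarray(rmse_b, dtype=float)
    better = np.flatnonzero(gain > min_gain)
    return None if better.size == 0 else int(densities[better[0]])


# Illustrative values only, not results from this notebook.
grid = [1, 2, 3, 5, 8]
toy_a = [1.00, 1.00, 1.20, 1.50, 2.00]
toy_b = [1.00, 1.05, 1.10, 1.20, 1.30]

first_win = crossover_density(grid, toy_a, toy_b)
no_win = crossover_density(grid, toy_a, toy_a)
print(first_win, no_win)
```

Returning `None` when the curves never separate keeps the "no crossover observed" case explicit instead of reporting a misleading density.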

N_SOURCES_GRID = [1, 2, 3, 5, 8, 10, 15, 20]
N_TRIALS = 10
SWEEP_RNG = np.random.default_rng(2027)
REGULARISATION = 1e-2
LAM_FISTA = 0.05
print(f'Sweep grid: {N_SOURCES_GRID}, {N_TRIALS} trials each')
Sweep grid: [1, 2, 3, 5, 8, 10, 15, 20], 10 trials each

Sweep Helper

def sweep_trial(
    config: InstrumentConfig,
    basis: EigenspectraBasis,
    image_shape: tuple,
    n_sources: int,
    rng: np.random.Generator,
    noise_model: NoiseModel,
) -> dict:
    """One Monte Carlo trial: build operators, solve, compute RMSE.

    Returns dict with keys 'rmse_lsqr' and 'rmse_fista'.
    """
    n_rows, n_cols = image_shape
    n_pix = n_rows * n_cols
    m = basis.n_components

    # Random source positions (row, col)
    flat_idx = rng.choice(n_pix, size=n_sources, replace=False)
    src_pos = np.column_stack([
        flat_idx // n_cols, flat_idx % n_cols
    ]).astype(np.float64)

    # Ground-truth coefficients
    a_true = rng.standard_normal(n_sources * m)

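    # Build operators. The sparse operator does not depend on the source
    # positions, so it could be built once outside the trial loop.
    sp_op  = SciPySparseOperator.build(config, basis, image_shape)
    jx_op  = JAXOperator.build(config, basis, image_shape, src_pos)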

    # Forward model + noise
    f_clean = jx_op.apply(a_true).reshape(image_shape)
    f_noisy_trial = noise_model.sample(f_clean, rng)

    # Support mask for SpectralSolver
    mask = np.zeros(n_pix * m, dtype=bool)
    for p in flat_idx:
        mask[p * m : (p + 1) * m] = True

    # LSQR
    a_lsqr_full = SpectralSolver(
        sp_op, noise_model=noise_model, regularisation=REGULARISATION
    ).solve(f_noisy_trial, support_mask=mask)
    a_lsqr = np.concatenate([
        a_lsqr_full[p * m : (p + 1) * m] for p in flat_idx
    ])

    # FISTA
    a_fista = JAXProximalSolver(
        jx_op, noise_model=noise_model, lam=LAM_FISTA, max_iter=200
    ).solve(f_noisy_trial)

    rmse_lsqr  = float(np.sqrt(np.mean((a_lsqr  - a_true)**2)))
    rmse_fista = float(np.sqrt(np.mean((a_fista - a_true)**2)))
    return {'rmse_lsqr': rmse_lsqr, 'rmse_fista': rmse_fista}

print('sweep_trial helper defined')
sweep_trial helper defined

Run Sweep

sweep_results = {n: [] for n in N_SOURCES_GRID}

for n_src in N_SOURCES_GRID:
    for trial in range(N_TRIALS):
        trial_rng = np.random.default_rng(SWEEP_RNG.integers(0, 2**31))
        res = sweep_trial(
            config, basis, IMAGE_SHAPE, n_src, trial_rng, NOISE_MODEL
        )
        sweep_results[n_src].append(res)
    lsqr_m  = np.mean([r['rmse_lsqr']  for r in sweep_results[n_src]])
    fista_m = np.mean([r['rmse_fista'] for r in sweep_results[n_src]])
    print(f'n={n_src:2d}: LSQR {lsqr_m:.4f}  FISTA {fista_m:.4f}')
n= 1: LSQR 0.8776  FISTA 0.8776
n= 2: LSQR 0.9742  FISTA 0.9742
n= 3: LSQR 0.9222  FISTA 0.9222
n= 5: LSQR 1.0070  FISTA 1.0070
n= 8: LSQR 0.9689  FISTA 0.9689
n=10: LSQR 0.9993  FISTA 0.9993
n=15: LSQR 1.0336  FISTA 1.0336
n=20: LSQR 1.0027  FISTA 1.0027
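
The loop above seeds each trial with an integer drawn from `SWEEP_RNG`, which is reproducible. An alternative that NumPy provides for creating independent streams is `SeedSequence.spawn`. A minimal sketch follows; `trial_generators` is a hypothetical helper and is not used elsewhere in this notebook.

```python
import numpy as np


def trial_generators(root_seed, n_trials):
    """Independent, reproducible generators, one per Monte Carlo trial."""
    children = np.random.SeedSequence(root_seed).spawn(n_trials)
    return [np.random.default_rng(child) for child in children]


rngs = trial_generators(2027, 3)
draws = [g.standard_normal(4) for g in rngs]

# Rebuilding from the same root seed reproduces the same streams.
repeat = trial_generators(2027, 3)[0].standard_normal(4)
print(np.array_equal(draws[0], repeat))
```

Spawned child sequences are designed not to overlap, which becomes relevant if the trials are ever run in parallel.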

RMSE vs Source Density

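At low densities (one or two sources) both solvers are expected to achieve similar RMSE, because there is little spectral overlap and the Tikhonov prior is sufficient. As the number of sources grows and trace overlap becomes severe, LSQR’s smooth regularisation should find it harder to distinguish contaminating from target flux, while FISTA’s group-\(\ell_1\) prior should hold RMSE lower. The sweep above does not show this separation: the mean RMSE of the two solvers agrees to four decimal places at every density, ranging from 0.88 to 1.03 with no clear trend. An RMSE near 1 is also what all-zero recovered coefficients would score against unit-variance ground truth, which suggests that neither solver is extracting much signal in this configuration. The guideline that fields with more than ~5 sources in the image favour the proximal solver, despite its higher per-iteration cost, is therefore not yet supported by these numbers.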
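
Because both solvers see the same noisy image in every trial, a paired comparison of per-trial RMSE is more sensitive than comparing two shaded bands. The sketch below computes the mean and standard error of per-trial differences from illustrative values; `paired_gain` is a hypothetical helper, and the inputs would be the per-trial `rmse_lsqr` and `rmse_fista` lists at one density.

```python
import numpy as np


def paired_gain(rmse_a, rmse_b):
    """Mean and standard error of per-trial differences (a - b); positive favours b."""
    d = np.asarray(rmse_a, dtype=float) - np.asarray(rmse_b, dtype=float)
    mean = float(d.mean())
    sem = float(d.std(ddof=1) / np.sqrt(d.size)) if d.size > 1 else float("nan")
    return mean, sem


# Illustrative values only, not results from this notebook.
toy_lsqr = [1.0, 1.2, 0.9, 1.1]
toy_fista = [0.9, 1.0, 0.9, 1.0]

gain_mean, gain_sem = paired_gain(toy_lsqr, toy_fista)
tie_mean, tie_sem = paired_gain(toy_lsqr, toy_lsqr)
print(gain_mean, gain_sem)
print(tie_mean, tie_sem)
```

A mean difference several standard errors above zero would indicate a real advantage for the second solver. A difference of exactly zero in every trial points to the two solvers returning the same solution.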

def sweep_stats(key: str) -> tuple[np.ndarray, np.ndarray]:
    """Mean and standard deviation of one RMSE key at each source density."""
    vals = np.array([[r[key] for r in sweep_results[n]] for n in N_SOURCES_GRID])
    return vals.mean(axis=1), vals.std(axis=1)

ns_arr = np.array(N_SOURCES_GRID)
lsqr_means,  lsqr_stds  = sweep_stats('rmse_lsqr')
fista_means, fista_stds = sweep_stats('rmse_fista')

fig, ax = plt.subplots(figsize=(8, 5))
ax.fill_between(ns_arr, lsqr_means - lsqr_stds, lsqr_means + lsqr_stds,
                alpha=0.2, color='steelblue')
ax.fill_between(ns_arr, fista_means - fista_stds, fista_means + fista_stds,
                alpha=0.2, color='tomato')
ax.plot(ns_arr, lsqr_means,  'o-', color='steelblue', label='LSQR (SpectralSolver)', lw=2)
ax.plot(ns_arr, fista_means, 's-', color='tomato',    label='FISTA (JAXProximalSolver)', lw=2)
ax.set_xlabel('Number of sources')
ax.set_ylabel('RMSE (coefficient units)')
ax.set_title('Extraction RMSE vs Source Density: LSQR vs FISTA')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
[Figure: Extraction RMSE vs number of sources for LSQR and FISTA, with shaded ±1σ bands.]

Summary

For sparsely populated NIRISS WFSS fields, LSQR (Tikhonov) spectral extraction is fast and the overhead of FISTA is not justified. The expectation that group-\(\ell_1\) regularisation reduces spectral RMSE once source density exceeds roughly five objects in the image, by enforcing per-source sparsity and limiting cross-contamination between overlapping traces, is not confirmed by the runs shown here: the two solvers return identical RMSE at every density tested, and the clean signal in the 50×20 sub-image is negligible compared with the read noise. Future work should repeat the comparison with a detectable signal, quantify the crossover as a function of field-of-view, filter throughput, and signal-to-noise ratio, and validate the result on real NIRISS commissioning data from crowded stellar fields.