Forward Operators

The forward operator \(H\) encodes the mapping from source spectral coefficients \(\mathbf{a}\) to detector pixel values \(\mathbf{f} = H\mathbf{a}\). It is built once from an InstrumentConfig and an EigenspectraBasis, then passed to a solver.

spectrex ships two concrete operator implementations, both satisfying ForwardOperatorProtocol. Any class that implements apply and apply_adjoint with matching signatures, and exposes the image_shape and n_coefficients members, can be substituted transparently.

ForwardOperatorProtocol

ForwardOperatorProtocol is a Protocol decorated with @runtime_checkable; it defines the interface every operator must satisfy. The two required methods are:

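  • apply(a): computes \(H\mathbf{a}\). Takes a flat coefficient vector of shape (n_coefficients,) (for JAXOperator, a (K, M) array flattened to length K * M) and returns the flattened detector image of shape (n_rows * n_cols,).

  • apply_adjoint(f): computes \(H^\top \mathbf{f}\), the transpose operation used by iterative solvers. Takes a flattened detector image of shape (n_rows * n_cols,) and returns a coefficient vector of shape (n_coefficients,).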

Both n_coefficients and image_shape must be readable attributes (typically properties) so that solvers can allocate correctly sized buffers without inspecting the operator internals.
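
To make the role of these members concrete, here is a minimal sketch of a solver written only against the protocol. DenseOperator and landweber are hypothetical names for illustration, not part of spectrex, and the solver is a plain Landweber (gradient-descent) iteration on \(\tfrac12\|H\mathbf{a}-\mathbf{f}\|^2\), chosen because it needs nothing beyond apply, apply_adjoint and the two shape members.

```python
import numpy as np


class DenseOperator:
    """Hypothetical toy operator: satisfies the protocol with an explicit dense H."""

    def __init__(self, H: np.ndarray, image_shape: tuple[int, int]):
        self._H = np.asarray(H, dtype=float)
        self._image_shape = image_shape

    @property
    def image_shape(self) -> tuple[int, int]:
        return self._image_shape

    @property
    def n_coefficients(self) -> int:
        return self._H.shape[1]

    def apply(self, a_tilde: np.ndarray) -> np.ndarray:
        return self._H @ a_tilde

    def apply_adjoint(self, f: np.ndarray) -> np.ndarray:
        return self._H.T @ f


def landweber(op, f: np.ndarray, step: float, n_iter: int = 200) -> np.ndarray:
    """Hypothetical solver: gradient descent on ||H a - f||^2 / 2 using only protocol members."""
    n_rows, n_cols = op.image_shape
    if f.shape != (n_rows * n_cols,):
        raise ValueError(f"expected f of shape ({n_rows * n_cols},), got {f.shape}")
    # The coefficient buffer is sized from the protocol, not from operator internals.
    a = np.zeros(op.n_coefficients)
    for _ in range(n_iter):
        residual = f - op.apply(a)                  # shape (n_rows * n_cols,)
        a = a + step * op.apply_adjoint(residual)   # shape (n_coefficients,)
    return a
```

With H = [[1, 0], [0, 1], [1, 1], [0, 0]] wrapped as a 2 × 2 image, landweber(op, op.apply(a_true), step=1/3) recovers a_true; the step must lie in \(0 < \tau < 2/\sigma_\text{max}(H)^2\), and \(\sigma_\text{max}^2 = 3\) here.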

SciPySparseOperator

SciPySparseOperator stores \(H\) as a CSR sparse matrix. Build it with build(), which loops over sources and diffraction orders, calling get_trace() and the basis's integrated_weights() to fill the matrix entries. The matrix structure does not depend on a source catalogue: every detector pixel is treated as a potential source, and all sources and all PCA components are packed into one flat coefficient vector of length n_rows * n_cols * n_components.
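
The per-pixel layout can be illustrated with a toy CSR build. This is a sketch under invented assumptions: a 3 × 4 detector, M = 2, and a made-up trace in which each source pixel disperses onto itself and the next two pixels in its row. It does not call get_trace() or integrated_weights(); the weight formula is a placeholder.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy sizes (hypothetical): 3 x 4 detector, M = 2 basis components.
n_rows, n_cols, M = 3, 4, 2
n_pix = n_rows * n_cols

rows, cols, vals = [], [], []
for src in range(n_pix):                 # every pixel is a potential source
    r, c = divmod(src, n_cols)
    for step in range(3):                # made-up trace: 3 pixels along the row
        if c + step >= n_cols:
            break                        # trace leaves the detector
        pix = r * n_cols + c + step
        for m in range(M):
            rows.append(pix)
            cols.append(src * M + m)     # coefficient index for (source, component)
            vals.append(1.0 / (1 + step + m))  # placeholder for sensitivity x basis weight

# Duplicate (row, col) entries, if any, are summed during the conversion to CSR.
H = csr_matrix((vals, (rows, cols)), shape=(n_pix, n_pix * M))

a = np.random.default_rng(0).normal(size=n_pix * M)
f = H @ a        # forward pass, shape (n_pix,)
g = H.T @ f      # adjoint pass, shape (n_pix * M,)
```

The coefficient axis has n_pix * M entries regardless of how many real sources are in the field, which is why the matrix can be built once and reused for any catalogue on the same detector geometry.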

This operator is best suited to sparse fields (fewer than ~20 sources) or to exploratory analysis where quick iteration matters more than peak throughput. save() and load() write the matrix to disk and restore it, so expensive builds need not be repeated.

JAXOperator

JAXOperator uses a compact trace layout that never stores \(H\) as an explicit matrix with \(N_\text{pix}\) rows and \(K \times M\) columns:

  • trace_indices[K, O, L] — int32 array of detector pixel indices, shape (n_sources, n_orders, n_lambda)

  • weights[O, L, M] — float32 array of sensitivity-weighted basis values, shape (n_orders, n_lambda, n_components)

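Pixel index arithmetic is vectorised and JIT-compiled by JAX, so apply() and apply_adjoint() are compiled by XLA and run on both CPU and GPU without Python-level loops. Out-of-bounds wavelengths are routed to a ghost pixel at index n_pix, and the ghost pixel value is discarded before the detector image is returned. The sentinel is needed because JAX's .at[] indexing does not raise on out-of-range indices: by default, out-of-bounds scatter updates are silently dropped and out-of-bounds gathers are clamped to a valid index, so an invalid wavelength could otherwise read a real pixel in the adjoint.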
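
The scatter/gather logic behind this layout can be sketched in NumPy, assuming the trace_indices and weights shapes listed above. compact_apply and compact_adjoint are illustrative names, not spectrex functions; in JAX the scatter would be written image.at[idx].add(vals). np.add.at is used because plain fancy-index += would not accumulate repeated pixel indices.

```python
import numpy as np


def compact_apply(a_tilde, trace_indices, weights, n_pix):
    """Sketch of H @ a for the compact (trace_indices, weights) layout."""
    K = trace_indices.shape[0]
    M = weights.shape[-1]
    a = a_tilde.reshape(K, M)
    # Flux of source k, order o, wavelength l: sum over m of weights[o, l, m] * a[k, m].
    contrib = np.einsum("olm,km->kol", weights, a)
    image = np.zeros(n_pix + 1, dtype=contrib.dtype)          # last slot is the ghost pixel
    np.add.at(image, trace_indices.ravel(), contrib.ravel())  # unbuffered scatter-add
    return image[:n_pix]                                      # discard the ghost pixel


def compact_adjoint(f, trace_indices, weights):
    """Sketch of H.T @ f: gather along each trace, then project onto the basis."""
    padded = np.append(f, 0.0)          # ghost pixel reads as zero, never as a real pixel
    gathered = padded[trace_indices]    # shape (K, n_orders, n_lambda)
    return np.einsum("olm,kol->km", weights, gathered).ravel()
```

A dot-product test, checking that compact_apply(a) · y equals a · compact_adjoint(y) for random a and y, is a quick way to confirm that a hand-written adjoint really is the transpose.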

n_active and n_components are properties: n_active reports how many sources have at least one valid trace pixel, and n_components reports the number of PCA components.
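
Following the definition in the paragraph above, the active-source count can be read directly off trace_indices. count_active is a hypothetical helper, not the spectrex implementation; it assumes n_pix is the ghost sentinel as described.

```python
import numpy as np


def count_active(trace_indices: np.ndarray, n_pix: int) -> int:
    """Hypothetical helper: sources with at least one trace pixel that is not the ghost sentinel."""
    valid = trace_indices < n_pix                # True where the sample lands on the detector
    return int(valid.any(axis=(1, 2)).sum())     # any valid (order, wavelength) per source
```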

H-matrix Schematic

The diagram below sketches how apply accumulates source coefficients onto detector pixels. Only a representative subset of the \(K \times N_\text{pix}\) possible edges is drawn: three source groups and six pixels.

digraph hmatrix {
    rankdir=LR;
    node [fontname="Helvetica", fontsize=11];
    splines=polyline;

    // Source coefficient nodes
    a0 [label="a₀", shape=ellipse, style=filled, fillcolor="#d4edda"];
    a1 [label="a₁", shape=ellipse, style=filled, fillcolor="#d4edda"];
    a2 [label="a₂", shape=ellipse, style=filled, fillcolor="#d4edda"];

    // Pixel nodes
    p0 [label="p₀", shape=box, style=filled, fillcolor="#fff3cd"];
    p1 [label="p₁", shape=box, style=filled, fillcolor="#fff3cd"];
    p2 [label="p₂", shape=box, style=filled, fillcolor="#fff3cd"];
    p3 [label="p₃", shape=box, style=filled, fillcolor="#fff3cd"];
    p4 [label="p₄", shape=box, style=filled, fillcolor="#fff3cd"];
    p5 [label="p₅", shape=box, style=filled, fillcolor="#fff3cd"];

    // Weighted edges (representative)
    a0 -> p0 [label="w₀₀"];
    a0 -> p1 [label="w₀₁"];
    a0 -> p2 [label="w₀₂"];
    a1 -> p1 [label="w₁₁"];
    a1 -> p3 [label="w₁₃"];
    a1 -> p4 [label="w₁₄"];
    a2 -> p2 [label="w₂₂"];
    a2 -> p4 [label="w₂₄"];
    a2 -> p5 [label="w₂₅"];
}

Each edge label \(w_{kp}\) is the sensitivity-weighted PCA basis contribution of source group \(k\) (its M coefficients) to detector pixel \(p\), integrated over the parts of the trace that land on \(p\). apply() performs this scatter-add; apply_adjoint() performs the corresponding gather.
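
The schematic can be reproduced numerically. This sketch collapses each source group to a single scalar coefficient \(a_k\) (M = 1) and uses made-up edge weights. It shows that looping over the diagram's edges as a scatter-add (forward) and a gather (adjoint) agrees with multiplying by an explicit 6 × 3 matrix and its transpose.

```python
import numpy as np

# Edges from the schematic as (source group k, pixel p); the weights are made up.
edges = [(0, 0), (0, 1), (0, 2), (1, 1), (1, 3), (1, 4), (2, 2), (2, 4), (2, 5)]
w = np.array([0.5, 1.0, 0.25, 0.8, 0.6, 0.4, 0.3, 0.9, 0.7])
k_idx = np.array([k for k, _ in edges])
p_idx = np.array([p for _, p in edges])


def apply(a, n_pix=6):
    """Scatter-add: each edge pushes w_kp * a_k onto pixel p."""
    f = np.zeros(n_pix)
    np.add.at(f, p_idx, w * a[k_idx])
    return f


def apply_adjoint(f, n_src=3):
    """Gather: each edge pulls w_kp * f_p back to source group k."""
    a = np.zeros(n_src)
    np.add.at(a, k_idx, w * f[p_idx])
    return a


# The same operator as an explicit matrix with H[p, k] = w_kp.
H = np.zeros((6, 3))
H[p_idx, k_idx] = w
```

Pixel p₁, for example, collects w₀₁·a₀ + w₁₁·a₁, since it is reached by edges from both a₀ and a₁.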

API Reference

class spectrex.ForwardOperatorProtocol

Protocol for the grism dispersion operator \(H\).

Any object satisfying this protocol can be passed to SpectralSolver or JAXProximalSolver. Conformance is structural: no inheritance from this class is required. Use isinstance(obj, ForwardOperatorProtocol) to check at runtime; this check only verifies that the required members exist, not their signatures or return shapes.
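
Structural conformance can be demonstrated with a stand-in protocol. OperatorLike below is written for this example only; it mirrors the members documented here but is not the spectrex source, and IdentityOperator and MissingAdjoint are hypothetical classes.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class OperatorLike(Protocol):
    """Stand-in mirroring the documented members of ForwardOperatorProtocol."""

    image_shape: tuple[int, int]
    n_coefficients: int

    def apply(self, a_tilde: np.ndarray) -> np.ndarray: ...

    def apply_adjoint(self, f: np.ndarray) -> np.ndarray: ...


class IdentityOperator:
    """Conforms without inheriting from OperatorLike."""

    def __init__(self, n_rows: int, n_cols: int):
        self.image_shape = (n_rows, n_cols)
        self.n_coefficients = n_rows * n_cols

    def apply(self, a_tilde):
        return np.asarray(a_tilde)

    def apply_adjoint(self, f):
        return np.asarray(f)


class MissingAdjoint:
    """Fails the isinstance check: apply_adjoint is absent."""

    image_shape = (2, 2)
    n_coefficients = 4

    def apply(self, a_tilde):
        return a_tilde
```

Because isinstance only looks for member names, a class whose apply had the wrong signature would still pass, so an explicit shape check on one apply call remains a useful sanity test.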

image_shape: tuple[int, int]

(n_rows, n_cols) of the detector image.

n_coefficients: int

Total length of the flattened coefficient vector: n_sources * n_components for JAXOperator, and n_rows * n_cols * n_components for SciPySparseOperator.

apply(a_tilde)

Forward pass: \(H\,\mathbf{a}\).

Parameters:

a_tilde (numpy.ndarray) – Coefficient vector, shape (n_coefficients,).

Returns:

Flattened dispersed image, shape (n_rows * n_cols,).

Return type:

numpy.ndarray

apply_adjoint(f)

Adjoint pass: \(H^\top \mathbf{f}\).

Parameters:

f (numpy.ndarray) – Flattened dispersed image, shape (n_rows * n_cols,).

Returns:

Coefficient vector, shape (n_coefficients,).

Return type:

numpy.ndarray

class spectrex.SciPySparseOperator(H, image_shape)

Bases: object

Grism forward operator backed by a scipy CSR sparse matrix.

Build from calibration data with build(), or load a previously cached operator with load().

Parameters:
  • H (csr_matrix) – Sparse forward matrix, shape (n_rows * n_cols, n_rows * n_cols * n_components).

  • image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.

apply(a_tilde)

Forward pass: H @ a_tilde.

Parameters:

a_tilde (np.ndarray) – Shape (n_coefficients,).

Returns:

Shape (n_rows * n_cols,).

Return type:

np.ndarray

apply_adjoint(f)

Adjoint pass: H.T @ f.

Parameters:

f (np.ndarray) – Shape (n_rows * n_cols,).

Returns:

Shape (n_coefficients,).

Return type:

np.ndarray

classmethod build(config, basis, image_shape)

Build the sparse forward matrix H from scratch.

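Parameters:
  • config (InstrumentConfig)

  • basis (EigenspectraBasis)

  • image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.
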
Return type:

SciPySparseOperator

Notes

Build complexity is O(n_rows * n_cols * n_wavelengths) per diffraction order. For full NIRISS (2048 × 2048) this will take minutes. Cache the result with save().
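
Because a full build takes minutes, a load-or-build pattern avoids repeating it. load_or_build is a hypothetical helper sketched here, not part of spectrex; it assumes only what is documented on this page, namely that save() appends .npz when the extension is missing and that load() reads the file save() wrote.

```python
from pathlib import Path
from typing import Callable, TypeVar

T = TypeVar("T")


def load_or_build(path, build: Callable[[], T], load: Callable[[Path], T]) -> T:
    """Hypothetical helper: reuse a cached operator file, or build and cache it.

    build() must return an object with a save(path) method; load(path) restores it.
    """
    path = Path(path)
    # save() appends .npz when it is missing, so look for the file it will actually write.
    if path.suffix != ".npz":
        path = path.with_name(path.name + ".npz")
    if path.exists():
        return load(path)
    op = build()       # the expensive step (minutes for a full 2048 x 2048 detector)
    op.save(path)
    return op
```

With spectrex this might be called as load_or_build("h_niriss", lambda: SciPySparseOperator.build(config, basis, (2048, 2048)), SciPySparseOperator.load), where the file name is illustrative.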

classmethod load(path)

Load a saved operator from an .npz file.

Parameters:

path (Path) – File written by save().

Return type:

SciPySparseOperator

save(path)

Save the operator to a single .npz file.

Parameters:

path (Path) – Output path. The .npz extension is added if absent.

Return type:

None
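
A CSR matrix is fully described by its data, indices and indptr arrays plus its shape, so it fits naturally in one .npz archive. This is a sketch of that idea, not the spectrex file format: save_csr, load_csr and the key names are hypothetical. scipy.sparse.save_npz also exists, but it stores only the matrix, whereas here image_shape is bundled alongside it.

```python
import numpy as np
from scipy.sparse import csr_matrix


def save_csr(path, H: csr_matrix, image_shape) -> None:
    """Hypothetical: store the CSR arrays, matrix shape and image shape in one .npz file."""
    np.savez(path, data=H.data, indices=H.indices, indptr=H.indptr,
             shape=np.array(H.shape), image_shape=np.array(image_shape))


def load_csr(path):
    """Hypothetical inverse of save_csr; returns (H, image_shape)."""
    with np.load(path) as z:
        shape = tuple(int(n) for n in z["shape"])
        H = csr_matrix((z["data"], z["indices"], z["indptr"]), shape=shape)
        image_shape = tuple(int(n) for n in z["image_shape"])
    return H, image_shape
```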

class spectrex.JAXOperator(trace_indices, weights, image_shape)

Bases: object

Grism forward operator using compact trace index storage.

Unlike SciPySparseOperator, this class never materialises a full sparse matrix. Instead it stores:

  • trace_indices[k, o, λ] — flat pixel index where source k, dispersion order o, wavelength index λ lands on the detector. Out-of-bounds wavelengths use n_pix (ghost pixel sentinel).

  • weights[o, λ, m] — shared instrument response × basis weight. Shape is independent of image size and number of sources.

Memory scales as O(K × n_orders × n_lambda) for trace_indices plus O(n_orders × n_lambda × M) for weights, rather than O(N_pix² × M) for a full per-pixel matrix, making it tractable for full NIRISS 2048 × 2048.
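
A quick byte count makes the scaling concrete. The sizes here (K = 1000 sources, 3 orders, 500 wavelength samples, M = 10) are illustrative assumptions, not NIRISS calibration values, and the dense figure is the O(N_pix² × M) bound quoted above; a CSR matrix stores only its nonzeros, so SciPySparseOperator's real footprint is far below it.

```python
def compact_bytes(K: int, n_orders: int, n_lambda: int, M: int) -> int:
    """Bytes for int32 trace_indices (K, O, L) plus float32 weights (O, L, M)."""
    return 4 * K * n_orders * n_lambda + 4 * n_orders * n_lambda * M


def dense_per_pixel_bytes(n_rows: int, n_cols: int, M: int) -> int:
    """Bytes for a dense float32 matrix of shape (N_pix, N_pix * M)."""
    n_pix = n_rows * n_cols
    return 4 * n_pix * n_pix * M


compact = compact_bytes(K=1000, n_orders=3, n_lambda=500, M=10)  # 6,060,000 bytes (~6 MB)
dense = dense_per_pixel_bytes(2048, 2048, M=10)                  # 40 * 2**44 bytes (~704 TB)
```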

Parameters:
  • trace_indices (np.ndarray) – Shape (K, n_orders, n_lambda), dtype int32. Values in [0, n_pix]; n_pix is the ghost pixel sentinel.

  • weights (np.ndarray) – Shape (n_orders, n_lambda, n_components), dtype float32.

  • image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.

apply(a_tilde)

Forward pass: H @ a_tilde.

Parameters:

a_tilde (np.ndarray) – Coefficient vector, shape (K * M,).

Returns:

Flattened dispersed image, shape (n_rows * n_cols,).

Return type:

np.ndarray

apply_adjoint(f)

Adjoint pass: H.T @ f.

Parameters:

f (np.ndarray) – Flattened dispersed image, shape (n_rows * n_cols,).

Returns:

Coefficient vector, shape (K * M,).

Return type:

np.ndarray

classmethod build(config, basis, image_shape, source_positions)

Build from calibration data and a source catalogue.

Parameters:
  • config (InstrumentConfig)

  • basis (EigenspectraBasis)

  • image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.

  • source_positions (np.ndarray) – Shape (K, 2) with (row, col) float positions for each source. Sub-pixel positions are accepted.

Return type:

JAXOperator
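
The core of such a build is mapping each (source, order, wavelength) sample to a flat pixel index, with the ghost sentinel for samples that fall off the detector. The sketch below uses a hypothetical trace model, a fixed (row, col) offset per order and wavelength that is rounded to the nearest pixel. The real build derives traces from InstrumentConfig and may treat sub-pixel positions differently.

```python
import numpy as np


def linear_trace_indices(source_positions, image_shape, offsets):
    """Hypothetical build step producing a (K, n_orders, n_lambda) int32 index array.

    offsets has shape (n_orders, n_lambda, 2): the (row, col) displacement of each
    wavelength sample from the source position (a stand-in for the real trace model).
    """
    n_rows, n_cols = image_shape
    n_pix = n_rows * n_cols
    pos = np.asarray(source_positions, dtype=float)[:, None, None, :]  # (K, 1, 1, 2)
    rc = np.rint(pos + offsets[None]).astype(np.int64)                 # nearest pixel, (K, O, L, 2)
    row, col = rc[..., 0], rc[..., 1]
    on_detector = (row >= 0) & (row < n_rows) & (col >= 0) & (col < n_cols)
    flat = row * n_cols + col
    # Off-detector samples point at the ghost pixel n_pix.
    return np.where(on_detector, flat, n_pix).astype(np.int32)
```

For a 3 × 5 detector and a single order that steps one column per wavelength sample, a source at (1.0, 3.6) lands on pixel 9 and then runs off the right edge into the ghost pixel 15.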

classmethod load(path)

Load a serialised operator from a .npz archive.

Parameters:

path (Path) – File written by save().

Return type:

JAXOperator

property n_active: int

Number of active sources K.

property n_components: int

Number of basis components M.

save(path)

Serialise to a .npz archive.

Parameters:

path (Path) – Output path. The .npz extension is added by numpy.savez() if absent.

Return type:

None

See also