Forward Operators
The forward operator \(H\) encodes the mapping from source spectral
coefficients \(\mathbf{a}\) to detector pixel values \(\mathbf{f} = H\mathbf{a}\).
It is built once from an InstrumentConfig and an
EigenspectraBasis, then passed to a solver.
spectrex ships two concrete operator implementations, both satisfying
ForwardOperatorProtocol. Any class that implements
apply and apply_adjoint with matching signatures, and exposes the
n_coefficients and image_shape properties, can be substituted
transparently.
ForwardOperatorProtocol
ForwardOperatorProtocol is a
runtime_checkable() Protocol that defines the
interface every operator must satisfy. The two required methods are:
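- apply(a): computes \(H\mathbf{a}\); takes a flattened coefficient vector of shape (n_coefficients,) (K * M entries for K sources and M components in JAXOperator) and returns the flattened detector image of shape (n_rows * n_cols,).
- apply_adjoint(f): computes \(H^\top \mathbf{f}\); the transpose operation used by iterative solvers, mapping a flattened detector image back to a coefficient vector.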
Both n_coefficients and image_shape must be properties so that solvers
can allocate correctly sized buffers without inspecting the operator internals.
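The structural contract described above can be sketched with typing.Protocol. OperatorLike below is a hypothetical restatement built only from the four members named in this section, not the actual spectrex source, and DenseOperator is a hypothetical toy that wraps a dense NumPy matrix. DenseOperator never inherits from the protocol, yet isinstance() accepts it.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class OperatorLike(Protocol):
    """Hypothetical restatement of the operator contract (not spectrex's code)."""

    @property
    def n_coefficients(self) -> int: ...

    @property
    def image_shape(self) -> tuple[int, int]: ...

    def apply(self, a_tilde: np.ndarray) -> np.ndarray: ...

    def apply_adjoint(self, f: np.ndarray) -> np.ndarray: ...


class DenseOperator:
    """Toy operator backed by a dense matrix; it does NOT inherit from OperatorLike."""

    def __init__(self, H: np.ndarray, image_shape: tuple[int, int]):
        n_rows, n_cols = image_shape
        if H.shape[0] != n_rows * n_cols:
            raise ValueError("H must have one row per detector pixel")
        self._H = H
        self._image_shape = image_shape

    @property
    def n_coefficients(self) -> int:
        return self._H.shape[1]

    @property
    def image_shape(self) -> tuple[int, int]:
        return self._image_shape

    def apply(self, a_tilde: np.ndarray) -> np.ndarray:
        return self._H @ a_tilde  # flattened image, shape (n_rows * n_cols,)

    def apply_adjoint(self, f: np.ndarray) -> np.ndarray:
        return self._H.T @ f  # coefficient vector, shape (n_coefficients,)


# A solver can size its buffers from the properties alone.
op = DenseOperator(np.ones((6, 4)), image_shape=(2, 3))
assert isinstance(op, OperatorLike)
x = np.zeros(op.n_coefficients)
img = op.apply(x).reshape(op.image_shape)
```

Because conformance is structural, any third-party operator with these members passes the same check without importing the protocol.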
SciPySparseOperator
SciPySparseOperator stores \(H\) as a CSR sparse
matrix. Build it with build(), which
loops over sources and orders, calling
get_trace() and the basis
integrated_weights() to fill the matrix
entries. The resulting matrix has the same structure regardless of K: all sources
and all PCA components are packed into a single flat coefficient vector.
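The assembly step can be sketched as collecting (pixel, column, weight) triplets into a COO matrix and converting it to CSR, which sums duplicate entries where samples land on the same pixel. The helper assemble_h, the toy triplets, and the column packing column = k * M + m are all hypothetical; the real build() derives its entries from get_trace() and integrated_weights(), which are not reproduced here.

```python
import numpy as np
import scipy.sparse as sp


def assemble_h(pixels, sources, components, values, n_pix, n_sources, n_components):
    """Assemble a CSR forward matrix from per-sample triplets (hypothetical helper).

    The column index packs (source k, component m) as k * n_components + m;
    this packing order is an assumption for illustration.
    """
    cols = np.asarray(sources) * n_components + np.asarray(components)
    coo = sp.coo_matrix(
        (values, (pixels, cols)), shape=(n_pix, n_sources * n_components)
    )
    return coo.tocsr()  # duplicate (pixel, column) entries are summed here


# Two sources, two components, a 2 x 3 detector (6 pixels).
n_rows, n_cols, K, M = 2, 3, 2, 2
pixels = [0, 1, 1, 2, 3, 4, 4]
sources = [0, 0, 0, 0, 1, 1, 1]
components = [0, 0, 0, 1, 0, 1, 1]
values = [1.0, 0.5, 0.5, 2.0, 3.0, 0.25, 0.75]
H = assemble_h(pixels, sources, components, values, n_rows * n_cols, K, M)

a = np.array([1.0, 1.0, 2.0, 4.0])  # packed as (a_00, a_01, a_10, a_11)
f = H @ a  # forward pass, shape (6,)
g = H.T @ f  # adjoint pass, shape (4,)
```

With CSR storage both passes are single sparse matrix-vector products, and H.T is a cheap view rather than a copy.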
This operator is best suited to sparse fields (fewer than ~20 sources) or to
exploratory analysis where quick iteration matters more than peak throughput.
save() / load()
serialise and restore the matrix from disk so expensive builds need not be
repeated.
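The build-once pattern can be sketched with SciPy's own sparse serialisation (scipy.sparse.save_npz and load_npz). The load_or_build helper and the cache file name are hypothetical, and the on-disk format here is SciPy's, not necessarily what SciPySparseOperator.save() writes.

```python
import tempfile
from pathlib import Path

import numpy as np
import scipy.sparse as sp


def load_or_build(cache_path: Path, build_fn):
    """Return (matrix, was_built): load a cached CSR matrix, or build and cache it."""
    if cache_path.exists():
        return sp.load_npz(cache_path), False
    H = sp.csr_matrix(build_fn())
    sp.save_npz(cache_path, H)  # writes a compressed .npz archive
    return H, True


def expensive_build():
    # Stand-in for the slow per-source, per-order loop: a small banded matrix.
    return sp.diags([1.0, 2.0, 3.0], offsets=[-1, 0, 1], shape=(6, 4), format="csr")


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "H.npz"
    H1, built1 = load_or_build(cache, expensive_build)  # first call builds
    H2, built2 = load_or_build(cache, expensive_build)  # second call loads
    identical = np.array_equal(H1.toarray(), H2.toarray())
```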
JAXOperator
JAXOperator uses a compact trace layout that avoids
storing a full \(N_\text{pix} \times KM\) matrix:
- trace_indices[K, O, L]: int32 array of detector pixel indices, shape (n_sources, n_orders, n_lambda).
- weights[O, L, M]: float32 array of sensitivity-weighted basis values, shape (n_orders, n_lambda, n_components).
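To make the trace_indices layout concrete, the sketch below fills it for a hypothetical linear trace model: wavelength sample l of order o lands on the source's own row at column col0 + offset[o] + l. Real traces come from the instrument calibration, not this formula, and flooring sub-pixel positions is a simplification. Samples that fall off the detector receive the ghost-pixel sentinel n_pix.

```python
import numpy as np


def linear_trace_indices(positions, image_shape, order_offsets, n_lambda):
    """Hypothetical trace model: (order o, sample l) -> (row, col0 + offset[o] + l).

    Returns an int32 array of shape (K, O, L) holding flat pixel indices,
    with off-detector samples set to the ghost-pixel sentinel n_pix.
    """
    n_rows, n_cols = image_shape
    n_pix = n_rows * n_cols
    pos = np.floor(np.asarray(positions, dtype=float)).astype(np.int64)  # sub-pixel -> pixel
    rows = pos[:, 0][:, None, None]  # (K, 1, 1)
    cols = (pos[:, 1][:, None, None]
            + np.asarray(order_offsets)[None, :, None]
            + np.arange(n_lambda)[None, None, :])  # (K, O, L)
    rows = np.broadcast_to(rows, cols.shape)
    on_detector = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)
    flat = np.where(on_detector, rows * n_cols + cols, n_pix)
    return flat.astype(np.int32)


idx = linear_trace_indices(
    positions=[[1.2, 0.7], [3.9, 5.0]],  # (row, col) per source
    image_shape=(4, 8),
    order_offsets=[0, 2],  # two dispersion orders
    n_lambda=3,
)
```

Here the second source's second order runs off the right edge, so its last two samples hold the sentinel 32 (= 4 × 8).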
Pixel index arithmetic is vectorised by JAX and JIT-compiled, so
apply() and
apply_adjoint() run at near-hardware throughput on
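both CPU and GPU. Out-of-bounds wavelengths are routed to a ghost pixel at
index n_pix; the ghost pixel value is discarded before the detector image is
returned, so invalid samples never contaminate real pixels. The explicit
sentinel matters because JAX does not raise index-out-of-range errors: with its
default semantics, out-of-bounds writes through the array .at[] property are
dropped and out-of-bounds reads are clamped to the edge of the array.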
n_active() reports how many sources have at least
one valid trace pixel; n_components() reports the
number of PCA components.
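The scatter and gather passes over this layout can be sketched in NumPy, which runs without an accelerator. In JAX, the np.add.at line would become out = out.at[idx].add(vals) inside a jitted function. The function names and the (K, M) reshaping of the coefficient vector are assumptions for illustration, not JAXOperator's actual code.

```python
import numpy as np


def compact_apply(a_tilde, trace_indices, weights, n_pix):
    """Forward pass H @ a using the compact (trace_indices, weights) layout."""
    K = trace_indices.shape[0]
    M = weights.shape[-1]
    a = a_tilde.reshape(K, M)
    # Flux deposited by every (source, order, wavelength) sample: shape (K, O, L).
    flux = np.einsum("olm,km->kol", weights, a)
    out = np.zeros(n_pix + 1)  # extra slot = ghost pixel
    np.add.at(out, trace_indices.ravel(), flux.ravel())  # unbuffered scatter-add
    return out[:n_pix]  # discard the ghost pixel


def compact_adjoint(f, trace_indices, weights):
    """Adjoint pass H.T @ f: gather pixel values, then project onto the basis."""
    f_ext = np.append(f, 0.0)  # ghost pixel reads as zero
    gathered = f_ext[trace_indices]  # shape (K, O, L)
    return np.einsum("kol,olm->km", gathered, weights).ravel()


def n_active(trace_indices, n_pix):
    """Number of sources with at least one on-detector sample."""
    return int(np.count_nonzero((trace_indices < n_pix).any(axis=(1, 2))))


rng = np.random.default_rng(0)
K, O, L, M, n_pix = 3, 2, 5, 4, 20
trace_indices = rng.integers(0, n_pix + 1, size=(K, O, L)).astype(np.int32)
trace_indices[:2, 0, 0] = [0, 1]  # guarantee sources 0 and 1 touch the detector
trace_indices[2] = n_pix  # source 2 falls entirely off-detector
weights = rng.standard_normal((O, L, M)).astype(np.float32)

a = rng.standard_normal(K * M)
f = rng.standard_normal(n_pix)
# Dot-product test: <H a, f> should equal <a, H^T f> up to rounding.
lhs = compact_apply(a, trace_indices, weights, n_pix) @ f
rhs = a @ compact_adjoint(f, trace_indices, weights)
```

Because the ghost slot is zero on the gather side and dropped on the scatter side, a source whose samples all hit the sentinel contributes nothing in either direction.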
H-matrix Schematic
The diagram below illustrates schematically how apply accumulates source
coefficients onto detector pixels. Not all \(K \times N_\text{pix}\)
edges are drawn; three representative source groups and six pixels are shown:
![H-matrix schematic: source coefficients a₀, a₁, a₂ (ellipses) connected to detector pixels p₀ to p₅ (boxes) by weighted edges a₀→p₀, p₁, p₂ (w₀₀, w₀₁, w₀₂); a₁→p₁, p₃, p₄ (w₁₁, w₁₃, w₁₄); a₂→p₂, p₄, p₅ (w₂₂, w₂₄, w₂₅)](../_images/graphviz-c876896d9694c82b588d12ac1f32b613bc0b7aff.png)
Each edge label \(w_{kp}\) is the sensitivity-weighted sum of the PCA
basis integrated over the trace of source group k at detector pixel p.
apply() performs this scatter–add; apply_adjoint
performs the corresponding gather.
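The schematic can be reproduced directly by treating each source group as a single coefficient. Over the nine drawn edges, apply is a scatter-add of w_kp · a_k into pixel p, and apply_adjoint gathers f_p along the same edges and sums per source. Only the edge pattern comes from the diagram; the numeric weights are arbitrary illustrative values.

```python
import numpy as np

# Edge list (k, p) from the schematic; weights are arbitrary illustrative values.
edges = [(0, 0), (0, 1), (0, 2),
         (1, 1), (1, 3), (1, 4),
         (2, 2), (2, 4), (2, 5)]
src = np.array([k for k, _ in edges])
pix = np.array([p for _, p in edges])
w = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
n_src, n_pix = 3, 6


def scatter_apply(a):
    """f[p] = sum over edges (k, p) of w_kp * a[k]."""
    f = np.zeros(n_pix)
    np.add.at(f, pix, w * a[src])
    return f


def gather_adjoint(f):
    """g[k] = sum over edges (k, p) of w_kp * f[p]: gather f at each edge, sum per source."""
    g = np.zeros(n_src)
    np.add.at(g, src, w * f[pix])
    return g


# The same operator as an explicit dense matrix, for comparison.
H = np.zeros((n_pix, n_src))
H[pix, src] = w

a = np.array([1.0, 10.0, 100.0])
f = scatter_apply(a)
```

Pixels p₁, p₂ and p₄ each receive two contributions, which is exactly where an unbuffered scatter-add (rather than plain assignment) is required.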
API Reference
- class spectrex.ForwardOperatorProtocol
Protocol for the grism dispersion operator \(H\).
Any object satisfying this protocol can be passed to
SpectralSolver or JAXProximalSolver. Conformance is structural: no inheritance from this class is required. Use isinstance(obj, ForwardOperatorProtocol) to check at runtime; this check only confirms that the required members exist, not that their signatures match.
- n_coefficients: int
Total length of the flattened coefficient vector (n_sources * n_components for JAXOperator).
- apply(a_tilde)
Forward pass: \(H\,\mathbf{a}\).
- Parameters:
a_tilde (numpy.ndarray) – Coefficient vector, shape (n_coefficients,).
- Returns:
Flattened dispersed image, shape (n_rows * n_cols,).
- Return type:
numpy.ndarray
- apply_adjoint(f)
Adjoint pass: \(H^\top \mathbf{f}\).
- Parameters:
f (numpy.ndarray) – Flattened dispersed image, shape (n_rows * n_cols,).
- Returns:
Coefficient vector, shape (n_coefficients,).
- Return type:
numpy.ndarray
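When substituting a custom operator, a quick consistency check is the dot-product (adjoint) test: for random \(\mathbf{a}\) and \(\mathbf{f}\), \(\langle H\mathbf{a}, \mathbf{f}\rangle\) should equal \(\langle \mathbf{a}, H^\top \mathbf{f}\rangle\) to rounding precision. The adjoint_mismatch helper below is a hypothetical sketch, not part of spectrex. It relies only on n_coefficients, apply, apply_adjoint, and the image_shape property described earlier; the two toy operators exist only to exercise it.

```python
import numpy as np


def adjoint_mismatch(op, n_trials=3, seed=0):
    """Largest relative mismatch between <H a, f> and <a, H^T f> over random trials."""
    rng = np.random.default_rng(seed)
    n_rows, n_cols = op.image_shape
    worst = 0.0
    for _ in range(n_trials):
        a = rng.standard_normal(op.n_coefficients)
        f = rng.standard_normal(n_rows * n_cols)
        lhs = float(np.dot(op.apply(a), f))
        rhs = float(np.dot(a, op.apply_adjoint(f)))
        worst = max(worst, abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-30))
    return worst


class _MatrixOp:
    """Minimal conforming operator used only to exercise the check."""

    def __init__(self, H, image_shape):
        self._H, self._shape = H, image_shape

    @property
    def n_coefficients(self):
        return self._H.shape[1]

    @property
    def image_shape(self):
        return self._shape

    def apply(self, a_tilde):
        return self._H @ a_tilde

    def apply_adjoint(self, f):
        return self._H.T @ f


class _BrokenOp(_MatrixOp):
    """Deliberately wrong adjoint (scaled by 2) to show what a failure looks like."""

    def apply_adjoint(self, f):
        return 2.0 * (self._H.T @ f)


H = np.random.default_rng(1).standard_normal((12, 5))
good = adjoint_mismatch(_MatrixOp(H, (3, 4)))
bad = adjoint_mismatch(_BrokenOp(H, (3, 4)))
```

A correct pair gives a mismatch near machine epsilon; anything much larger points to an inconsistency between apply and apply_adjoint that would stall or mislead an iterative solver.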
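- class spectrex.SciPySparseOperator(H, image_shape)
Bases: object
Grism forward operator backed by a SciPy CSR sparse matrix.
Build from calibration data with build(), or load a previously cached operator with load().
- Parameters:
H – CSR sparse forward matrix, shape (n_rows * n_cols, n_coefficients).
image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.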
- apply(a_tilde)
Forward pass: H @ a_tilde.
- Parameters:
a_tilde (np.ndarray) – Shape (n_coefficients,).
- Returns:
Shape (n_rows * n_cols,).
- Return type:
np.ndarray
- apply_adjoint(f)
Adjoint pass: H.T @ f.
- Parameters:
f (np.ndarray) – Shape (n_rows * n_cols,).
- Returns:
Shape (n_coefficients,).
- Return type:
np.ndarray
- classmethod build(config, basis, image_shape)
Build the sparse forward matrix H from scratch.
- Parameters:
config (InstrumentConfig)
basis (EigenspectraBasis)
image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.
- Return type:
SciPySparseOperator
Notes
Build complexity is O(n_rows * n_cols * n_wavelengths) per diffraction order. For full NIRISS (2048 × 2048) this takes minutes, so cache the result with save().
- class spectrex.JAXOperator(trace_indices, weights, image_shape)
Bases: object
Grism forward operator using compact trace index storage.
Unlike SciPySparseOperator, this class never materialises a full sparse matrix. Instead it stores:
- trace_indices[k, o, λ]: flat pixel index where source k, dispersion order o, wavelength index λ lands on the detector. Out-of-bounds wavelengths use n_pix (ghost pixel sentinel).
- weights[o, λ, m]: shared instrument response × basis weight. Shape is independent of image size and number of sources.
Memory scales as O(K × n_orders × n_lambda) rather than O(N_pix² × M), making it tractable for full NIRISS 2048 × 2048.
- Parameters:
trace_indices (np.ndarray) – int32 array, shape (n_sources, n_orders, n_lambda).
weights (np.ndarray) – float32 array, shape (n_orders, n_lambda, n_components).
image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.
- apply(a_tilde)
Forward pass: H @ a_tilde.
- Parameters:
a_tilde (np.ndarray) – Coefficient vector, shape (K * M,).
- Returns:
Flattened dispersed image, shape (n_rows * n_cols,).
- Return type:
np.ndarray
- apply_adjoint(f)
Adjoint pass: H.T @ f.
- Parameters:
f (np.ndarray) – Flattened dispersed image, shape (n_rows * n_cols,).
- Returns:
Coefficient vector, shape (K * M,).
- Return type:
np.ndarray
- classmethod build(config, basis, image_shape, source_positions)
Build from calibration data and a source catalogue.
- Parameters:
config (InstrumentConfig)
basis (EigenspectraBasis)
image_shape (tuple[int, int]) – (n_rows, n_cols) of the detector image.
source_positions (np.ndarray) – Shape (K, 2) with (row, col) float positions for each source. Sub-pixel positions are accepted.
- Return type:
JAXOperator
- classmethod load(path)
Load a serialised operator from a .npz archive.
- Parameters:
path (Path) – File written by save().
- Return type:
JAXOperator
- save(path)
Serialise to a .npz archive.
- Parameters:
path (Path) – Output path. The .npz extension is added by numpy.savez() if absent.
- Return type:
None
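A round trip through numpy.savez and numpy.load can be sketched as follows. The helpers save_compact and load_compact and the archive key names (trace_indices, weights, image_shape) are hypothetical; the actual layout written by JAXOperator.save() is not documented here. The sketch also demonstrates the extension behaviour noted above: numpy.savez appends .npz to a path that lacks it.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_compact(path, trace_indices, weights, image_shape):
    """Hypothetical serialiser for the compact operator arrays."""
    np.savez(path, trace_indices=trace_indices, weights=weights,
             image_shape=np.asarray(image_shape, dtype=np.int64))


def load_compact(path):
    """Hypothetical loader; returns (trace_indices, weights, image_shape)."""
    with np.load(path) as data:
        shape = tuple(int(n) for n in data["image_shape"])
        return data["trace_indices"], data["weights"], shape


trace_indices = np.array([[[0, 1, 6]]], dtype=np.int32)  # K=1, O=1, L=3; 6 = n_pix sentinel
weights = np.ones((1, 3, 2), dtype=np.float32)  # O=1, L=3, M=2

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "operator"  # no extension given
    save_compact(target, trace_indices, weights, (2, 3))
    written = target.with_suffix(".npz")  # savez appended the extension
    exists = written.exists()
    ti, w, shape = load_compact(written)
```

Storing image_shape alongside the arrays keeps the archive self-describing, since n_pix (and therefore the sentinel value) cannot be recovered from trace_indices alone.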
See also
Instrument Configuration & Spectral Basis — building the inputs to the operator
Solvers — passing the operator to a solver
Computational Benchmarks: JAXOperator vs SciPySparseOperator — memory footprint and runtime benchmarks for both operators