autogalaxy.TransformerNUFFT
- class TransformerNUFFT
Bases: object
JAX-native Non-Uniform FFT for image -> visibilities, backed by nufftax.
This is the default TransformerNUFFT in PyAutoArray. It uses the nufftax library (https://github.com/GragasLab/nufftax), a pure-JAX NUFFT implementation that supports jax.jit, jax.grad, and jax.vmap. It replaces the legacy TransformerNUFFTPyNUFFT (which wraps the non-differentiable pynufft library) as the default backend.
Convention recipe (matches TransformerDFT to ~1e-13 relative error across odd, even and non-square image sizes):

    image_flipped = image[::-1, :]
    x = 2 * pi * u_lambda * pixel_scale_rad
    y = 2 * pi * v_lambda * pixel_scale_rad
    offset_x = 0.5 if N_x is even else 0.0
    offset_y = 0.5 if N_y is even else 0.0
    shift = exp(-i * (offset_x * x + offset_y * y))
    visibilities = nufftax.nufft2d2(x, y, image_flipped, eps, -1) * shift

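The shift factor is the half-pixel correction between autoarray’s grid centre at index (N - 1) / 2 and nufftax’s mode 0 at index N // 2. The two coincide for odd N and differ by half a pixel for even N, which is why the offset is 0.5 only for even sizes. pynufft applies this correction internally; nufftax does not.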
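The recipe can be checked without nufftax using a small numpy sketch. Here `type2_direct` is a hypothetical direct-sum stand-in for `nufft2d2`, and its mode ordering (mode 0 at index N // 2, first array axis paired with y) is an assumption, not a statement about nufftax internals. `dft_autoarray` is a hypothetical direct DFT over an autoarray-style grid centred at (N - 1) / 2 with row 0 at the top. Multiplying by `shift` is what reconciles the two centring conventions for even N.

```python
import numpy as np


def dft_autoarray(image, u, v, pixel_scale_rad):
    """Direct DFT over an autoarray-style grid (hypothetical reference helper).

    Assumes row 0 is the top of the image (largest y), column 0 the left
    (smallest x), and pixel centres symmetric about index (N - 1) / 2.
    """
    n_y, n_x = image.shape
    y = ((n_y - 1) / 2 - np.arange(n_y)) * pixel_scale_rad
    x = (np.arange(n_x) - (n_x - 1) / 2) * pixel_scale_rad
    phase = -2j * np.pi * (u[:, None, None] * x[None, None, :]
                           + v[:, None, None] * y[None, :, None])
    return np.sum(image[None] * np.exp(phase), axis=(1, 2))


def type2_direct(xk, yk, f, isign):
    """Direct-sum stand-in for a type-2 NUFFT (hypothetical, not nufftax).

    Assumed mode convention: f[j, i] holds mode (k_y, k_x) = (j - N_y // 2, i - N_x // 2),
    i.e. mode 0 sits at index N // 2 along each axis.
    """
    n_y, n_x = f.shape
    k_y = np.arange(n_y) - n_y // 2
    k_x = np.arange(n_x) - n_x // 2
    phase = isign * 1j * (xk[:, None, None] * k_x[None, None, :]
                          + yk[:, None, None] * k_y[None, :, None])
    return np.sum(f[None] * np.exp(phase), axis=(1, 2))


def visibilities_via_recipe(image, u, v, pixel_scale_rad):
    """The convention recipe, with type2_direct standing in for nufftax.nufft2d2."""
    n_y, n_x = image.shape
    image_flipped = image[::-1, :]
    x = 2 * np.pi * u * pixel_scale_rad
    y = 2 * np.pi * v * pixel_scale_rad
    offset_x = 0.5 if n_x % 2 == 0 else 0.0
    offset_y = 0.5 if n_y % 2 == 0 else 0.0
    shift = np.exp(-1j * (offset_x * x + offset_y * y))
    return type2_direct(x, y, image_flipped, -1) * shift


# Usage: an image with even N_y and odd N_x, random uv points in wavelengths.
rng = np.random.default_rng(0)
image = rng.normal(size=(4, 7))
u, v = rng.uniform(-2e5, 2e5, size=(2, 16))
assert np.allclose(visibilities_via_recipe(image, u, v, 1e-6),
                   dft_autoarray(image, u, v, 1e-6))
```

Dropping `shift` breaks the agreement whenever either dimension is even, which is the half-pixel offset described above.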
- Parameters:
  - uv_wavelengths (ndarray) – The (u, v) coordinates of the measured visibilities in wavelengths, shape (n_vis, 2).
  - real_space_mask (Mask2D) – The 2D mask defining the real-space image grid.
  - eps (float) – Requested NUFFT precision passed to nufftax. Defaults to 1e-12 (effectively machine precision); relax to 1e-9 or 1e-6 for faster execution if the reduced accuracy is acceptable.
  - chunk_size (Optional[int]) – If set to a positive integer, the forward and adjoint NUFFT calls split the visibility axis into chunks of this size and iterate over them (via jax.lax.scan on the JAX path, a Python loop on the numpy path). This caps the nufftax gather-buffer allocation (~2 * chunk_size * nspread^2 * dtype_size) at the cost of per-chunk overhead. Required for visibility counts above ~5M on a 40-80 GB GPU. If None (default), a single one-shot call is used, preserving the existing behaviour for small-N callers (SMA-class datasets).
  - xp – Accepted for signature compatibility with the legacy class; not stored. The active backend is selected per call via the xp argument to visibilities_from / image_from.
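The chunking behaviour on the numpy path can be sketched as a plain loop over slices of the visibility axis. `forward_in_chunks` and `gather_buffer_estimate` are hypothetical helpers, and the `forward` callable stands in for whatever per-chunk NUFFT call is made; the padding-plus-`jax.lax.scan` remark in the comments is one plausible way to do the JAX path, not a description of the actual implementation.

```python
import numpy as np


def gather_buffer_estimate(chunk_size, nspread, dtype_size):
    """Approximate gather-buffer size from the docstring: 2 * chunk_size * nspread^2 * dtype_size."""
    return 2 * chunk_size * nspread**2 * dtype_size


def forward_in_chunks(forward, x, y, chunk_size=None):
    """Apply forward(x, y) over the visibility axis, optionally chunk by chunk.

    forward is a hypothetical callable mapping coordinate arrays of shape (n,)
    to n visibilities. This mirrors a numpy-path Python loop; a JAX version
    could instead pad to a multiple of chunk_size and iterate with jax.lax.scan,
    which needs equal-sized chunks.
    """
    if chunk_size is None:
        return forward(x, y)  # single one-shot call
    pieces = [
        forward(x[start:start + chunk_size], y[start:start + chunk_size])
        for start in range(0, x.shape[0], chunk_size)
    ]
    return np.concatenate(pieces)


# Usage: a tiny direct-sum forward operator, used only to exercise the chunking.
rng = np.random.default_rng(1)
image = rng.normal(size=(3, 3))
k = np.arange(3) - 1


def forward(x, y):
    phase = -1j * (x[:, None, None] * k[None, None, :] + y[:, None, None] * k[None, :, None])
    return np.sum(image[None] * np.exp(phase), axis=(1, 2))


x, y = rng.uniform(-np.pi, np.pi, size=(2, 10))
assert np.allclose(forward_in_chunks(forward, x, y, chunk_size=4),
                   forward_in_chunks(forward, x, y))
```

As a worked example of the formula, a hypothetical nspread of 8 with 16-byte complex values and chunk_size = 1,000,000 gives 2 * 1,000,000 * 64 * 16 = 2,048,000,000, i.e. roughly 2 GB per chunk.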
- grid
The real-space pixel grid in radians (computed from the mask).
- total_visibilities
Number of measured visibilities.
- total_image_pixels
Number of unmasked pixels in the image grid.
- adjoint_scaling
Scaling factor available for callers who want to apply an optional normalisation to the adjoint output. Provided for parity with the legacy class.
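A minimal sketch of how a radian grid and the unmasked-pixel count relate to a mask. `grid_radians_from_mask` is hypothetical; it assumes the pixel scale is given in arcseconds, that True marks a masked pixel, that row 0 is the top of the image, and that the grid is centred at index (N - 1) / 2 as described above. Under those assumptions the number of rows returned plays the role of total_image_pixels.

```python
import numpy as np

ARCSEC_TO_RAD = np.pi / (180.0 * 3600.0)


def grid_radians_from_mask(mask, pixel_scale_arcsec):
    """Hypothetical helper: (y, x) centres in radians of the unmasked pixels.

    Assumes True marks a masked pixel, row 0 is the top of the image and the
    grid is centred at index (N - 1) / 2 along each axis.
    """
    n_y, n_x = mask.shape
    rows, cols = np.nonzero(~mask)  # row-major order over unmasked pixels
    y = ((n_y - 1) / 2 - rows) * pixel_scale_arcsec * ARCSEC_TO_RAD
    x = (cols - (n_x - 1) / 2) * pixel_scale_arcsec * ARCSEC_TO_RAD
    return np.stack([y, x], axis=1)


# Usage: a 4x4 mask with two masked corners leaves 14 unmasked pixels.
mask = np.zeros((4, 4), dtype=bool)
mask[0, 0] = mask[3, 3] = True
grid = grid_radians_from_mask(mask, pixel_scale_arcsec=0.1)
total_image_pixels = grid.shape[0]
```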
- __init__(uv_wavelengths, real_space_mask, eps=1e-12, chunk_size=None, xp=numpy, **kwargs)
Methods
- __init__(uv_wavelengths, real_space_mask[, ...]) – JAX-native Non-Uniform FFT for image -> visibilities, backed by nufftax.
- image_from(visibilities[, ...]) – Adjoint NUFFT: visibilities -> real-space (dirty) image.
- transform_mapping_matrix(mapping_matrix[, xp]) – Apply the forward NUFFT to each column of a mapping matrix.
- visibilities_from(image[, xp]) – Forward NUFFT: real-space image -> visibilities at the configured uv points.
- visibilities_from(image, xp=numpy)
Forward NUFFT: real-space image -> visibilities at the configured uv points.
For numpy callers (xp=np) the result is materialised back to numpy before being wrapped in Visibilities. For JAX callers (xp=jnp) the result stays as a jax.Array so it can flow through jax.jit / jax.grad / jax.vmap without device round-trips.
- image_from(visibilities, use_adjoint_scaling=False, xp=numpy)
Adjoint NUFFT: visibilities -> real-space (dirty) image.
Implemented as nufftax.nufft2d1 with conj(shift) applied to the visibilities and a final row-flip to return to autoarray’s native orientation. The real part is taken to discard the imaginary residue, matching the legacy class’s behaviour.
Note that this is the mathematical adjoint of visibilities_from, with no kernel deconvolution applied. The dirty image therefore differs in absolute scale from the legacy TransformerNUFFTPyNUFFT adjoint (which applies pynufft’s internal IFFT and kernel deconvolution). The structure of the dirty image is the same, and the values match TransformerDFT.image_from exactly.
use_adjoint_scaling is accepted for API compatibility with the legacy class and is otherwise unused (the nufftax adjoint is already the mathematical adjoint; no extra normalisation is needed). This matches TransformerDFT.image_from semantics so the sparse-operator path is scale-consistent across both transformers.
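The adjoint steps can be checked with a dot-product test, again as a numpy sketch that does not call nufftax. The direct sums standing in for nufft2d2 and nufft2d1, their mode ordering (mode 0 at index N // 2) and the +1 sign used on the adjoint side are assumptions, and all helper names are hypothetical. For a real image I and complex visibilities V, Re⟨A I, V⟩ should equal ⟨I, image_from(V)⟩ up to rounding, which is what "mathematical adjoint" means here once the real part is taken.

```python
import numpy as np


def _modes(n):
    # Assumed mode convention: mode 0 at index n // 2.
    return np.arange(n) - n // 2


def _shift(x, y, n_y, n_x):
    # Half-pixel correction from the convention recipe.
    offset_x = 0.5 if n_x % 2 == 0 else 0.0
    offset_y = 0.5 if n_y % 2 == 0 else 0.0
    return np.exp(-1j * (offset_x * x + offset_y * y))


def forward_direct(image, x, y):
    """Recipe forward, with a direct type-2 sum (isign = -1) standing in for nufft2d2."""
    n_y, n_x = image.shape
    k_y, k_x = _modes(n_y), _modes(n_x)
    phase = -1j * (x[:, None, None] * k_x + y[:, None, None] * k_y[:, None])
    return np.sum(image[::-1, :][None] * np.exp(phase), axis=(1, 2)) * _shift(x, y, n_y, n_x)


def adjoint_direct(vis, x, y, shape):
    """conj(shift) on the visibilities, a direct type-1 sum (isign = +1) standing in
    for nufft2d1, a row-flip back to autoarray orientation, then the real part."""
    n_y, n_x = shape
    k_y, k_x = _modes(n_y), _modes(n_x)
    weighted = vis * np.conj(_shift(x, y, n_y, n_x))
    phase = 1j * (x[:, None, None] * k_x + y[:, None, None] * k_y[:, None])
    image_flipped = np.sum(weighted[:, None, None] * np.exp(phase), axis=0)
    return np.real(image_flipped[::-1, :])


# Usage: dot-product test on a random real image and random complex visibilities.
rng = np.random.default_rng(2)
image = rng.normal(size=(5, 4))
x, y = rng.uniform(-3.0, 3.0, size=(2, 12))
vis = rng.normal(size=12) + 1j * rng.normal(size=12)
lhs = np.real(np.vdot(forward_direct(image, x, y), vis))
rhs = np.sum(image * adjoint_direct(vis, x, y, image.shape))
assert np.isclose(lhs, rhs)
```

No kernel deconvolution appears anywhere in this pair, consistent with the note above that the adjoint's absolute scale differs from the legacy pynufft-based class.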
- transform_mapping_matrix(mapping_matrix, xp=numpy)
Apply the forward NUFFT to each column of a mapping matrix.
All columns are scattered into a single batched native-shape image of shape (n_src, N_y, N_x) and passed through nufft2d2 in one call (nufft2d2 supports a batched f). This avoids the per-column Python loop that, under jax.jit, would unroll into n_src separate NUFFT invocations and inflate the JIT graph for pixelization-heavy fits (notably double-source-plane fits).
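The batching idea can be sketched in numpy. `forward_batched` is a hypothetical direct-sum stand-in for one batched nufft2d2 call, using the same flip, mode-ordering and shift assumptions as the recipe in the class description. `transform_mapping_matrix_sketch` is also hypothetical and assumes a boolean mask where True is masked, with mapping-matrix rows in row-major order over unmasked pixels. Every column is transformed by a single contraction rather than an n_src-long Python loop.

```python
import numpy as np


def forward_batched(images, x, y):
    """Direct-sum forward operator applied to a batch of shape (n_src, N_y, N_x).

    Stand-in for a single batched nufft2d2 call (hypothetical, not nufftax).
    """
    n_src, n_y, n_x = images.shape
    k_y = np.arange(n_y) - n_y // 2
    k_x = np.arange(n_x) - n_x // 2
    offset_x = 0.5 if n_x % 2 == 0 else 0.0
    offset_y = 0.5 if n_y % 2 == 0 else 0.0
    shift = np.exp(-1j * (offset_x * x + offset_y * y))
    kernel = np.exp(-1j * (x[:, None, None] * k_x + y[:, None, None] * k_y[:, None]))
    # One contraction over (N_y, N_x) for every source image at once -> (n_src, n_vis).
    return np.einsum("syx,vyx->sv", images[:, ::-1, :], kernel) * shift


def transform_mapping_matrix_sketch(mapping_matrix, mask, x, y):
    """Scatter every column into a native-shape image, then transform the whole batch.

    Returns an array of shape (n_vis, n_src).
    """
    n_src = mapping_matrix.shape[1]
    images = np.zeros((n_src,) + mask.shape)
    images[:, ~mask] = mapping_matrix.T  # each row of images[:, ~mask] is one column
    return forward_batched(images, x, y).T


# Usage: 3 source pixels mapped onto a 4x5 image with two masked pixels.
rng = np.random.default_rng(3)
mask = np.zeros((4, 5), dtype=bool)
mask[0, 0] = mask[3, 4] = True
mapping_matrix = rng.random((int((~mask).sum()), 3))
x, y = rng.uniform(-3.0, 3.0, size=(2, 7))
transformed = transform_mapping_matrix_sketch(mapping_matrix, mask, x, y)
```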