From cb0546dd143d79171105c3d1f2a70e2abb81b860 Mon Sep 17 00:00:00 2001 From: Hitendrasinh Rathod Date: Sun, 21 Dec 2025 16:09:27 +0000 Subject: [PATCH 1/5] FEAT: Add 3D Radial Fourier Transform for medical image frequency analysis - Implement RadialFourier3D transform for radial frequency analysis - Add RadialFourierFeatures3D for multi-scale feature extraction - Include comprehensive tests (20/20 passing) - Support for magnitude, phase, and complex outputs - Handle anisotropic resolution in medical imaging - Fix numpy compatibility and spatial dimension handling Signed-off-by: Hitendrasinh Rathod Signed-off-by: Hitendrasinh Rathod --- monai/transforms/__init__.py | 4 +- monai/transforms/signal/__init__.py | 7 + monai/transforms/signal/radial_fourier.py | 350 ++++++++++++++++++++++ tests/test_radial_fourier.py | 196 ++++++++++++ tests/transforms/signal/__init__.py | 0 5 files changed, 555 insertions(+), 2 deletions(-) create mode 100644 monai/transforms/signal/radial_fourier.py create mode 100644 tests/test_radial_fourier.py create mode 100644 tests/transforms/signal/__init__.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 3fd33b76da..b2dcb965e3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -376,9 +376,9 @@ SignalRandAddSquarePulsePartial, SignalRandDrop, SignalRandScale, - SignalRandShift, - SignalRemoveFrequency, + SignalRemoveFrequency ) +from .signal import RadialFourier3D, RadialFourierFeatures3D from .signal.dictionary import SignalFillEmptyd, SignalFillEmptyD, SignalFillEmptyDict from .smooth_field.array import ( RandSmoothDeform, diff --git a/monai/transforms/signal/__init__.py b/monai/transforms/signal/__init__.py index 1e97f89407..5ed71ccb0e 100644 --- a/monai/transforms/signal/__init__.py +++ b/monai/transforms/signal/__init__.py @@ -8,3 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Signal processing transforms for medical imaging. +""" + +from .radial_fourier import RadialFourier3D, RadialFourierFeatures3D + +__all__ = ["RadialFourier3D", "RadialFourierFeatures3D"] diff --git a/monai/transforms/signal/radial_fourier.py b/monai/transforms/signal/radial_fourier.py new file mode 100644 index 0000000000..e58aefe7e5 --- /dev/null +++ b/monai/transforms/signal/radial_fourier.py @@ -0,0 +1,350 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +3D Radial Fourier Transform for medical imaging data. +""" + +from __future__ import annotations + +import math +from typing import Optional, Union + +from collections.abc import Sequence + +import numpy as np +import torch +from torch.fft import fftn, fftshift, ifftn, ifftshift + +from monai.config import NdarrayOrTensor +from monai.transforms.transform import Transform +from monai.utils import convert_data_type, optional_import + +# Optional imports for type checking +spatial, _ = optional_import("monai.utils", name="spatial") + + +class RadialFourier3D(Transform): + """ + Computes the 3D Radial Fourier Transform of medical imaging data. + + This transform converts 3D medical images into radial frequency domain representations, + which is particularly useful for handling anisotropic resolution common in medical scans + (e.g., different resolution in axial vs coronal planes). + + The radial transform provides rotation-invariant frequency analysis and can help + normalize frequency representations across datasets with different acquisition parameters. + + Args: + normalize: if True, normalize the output by the number of voxels. + return_magnitude: if True, return magnitude of the complex result. + return_phase: if True, return phase of the complex result. + radial_bins: number of radial bins for frequency aggregation. If None, returns full 3D spectrum. + max_frequency: maximum normalized frequency to include (0.0 to 1.0). + spatial_dims: spatial dimensions to apply transform to. Default is last three dimensions. + + Returns: + Radial Fourier transform of input data. Shape depends on parameters: + - If radial_bins is None: complex tensor of same spatial shape as input + - If radial_bins is set: real tensor of shape (radial_bins,) for magnitude/phase + + Example: + >>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True) + >>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth + >>> result = transform(image) # Shape: (1, 64) + """ + + def __init__( + self, + normalize: bool = True, + return_magnitude: bool = True, + return_phase: bool = False, + radial_bins: Optional[int] = None, + max_frequency: float = 1.0, + spatial_dims: Union[int, Sequence[int]] = (-3, -2, -1), + ) -> None: + super().__init__() + self.normalize = normalize + self.return_magnitude = return_magnitude + self.return_phase = return_phase + self.radial_bins = radial_bins + self.max_frequency = max_frequency + + if isinstance(spatial_dims, int): + spatial_dims = (spatial_dims,) + self.spatial_dims = tuple(spatial_dims) + + # Validate parameters + if not 0.0 < max_frequency <= 1.0: + raise ValueError(f"max_frequency must be in (0.0, 1.0], got {max_frequency}") + if radial_bins is not None and radial_bins < 1: + raise ValueError(f"radial_bins must be >= 1, got {radial_bins}") + if not return_magnitude and not return_phase: + raise ValueError("At least one of return_magnitude or return_phase must be True") + + def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor: + """ + Compute radial distance from frequency domain center. + + Args: + shape: spatial dimensions (D, H, W) or (H, W, D) depending on dims order. + + Returns: + Tensor of same spatial shape with radial distances. + """ + # Create frequency coordinates for each dimension + coords = [] + for dim_size in shape: + # Create frequency range from -0.5 to 0.5 + freq = torch.fft.fftfreq(dim_size) + coords.append(freq) + + # Create meshgrid and compute radial distance + mesh = torch.meshgrid(coords, indexing="ij") + radial = torch.sqrt(sum(c**2 for c in mesh)) + + return radial + + def _compute_radial_spectrum(self, spectrum: torch.Tensor, radial_coords: torch.Tensor) -> torch.Tensor: + """ + Compute radial average of frequency spectrum. + + Args: + spectrum: complex frequency spectrum (flattened 1D array). + radial_coords: radial distance for each frequency coordinate (flattened 1D array). + + Returns: + Radial average of spectrum (1D array of length radial_bins). + """ + if self.radial_bins is None: + return spectrum + + # Bin radial coordinates + max_r = self.max_frequency * 0.5 # Maximum normalized frequency + bin_edges = torch.linspace(0, max_r, self.radial_bins + 1, device=spectrum.device) + + # Initialize output + result_real = torch.zeros(self.radial_bins, dtype=spectrum.real.dtype, device=spectrum.device) + result_imag = torch.zeros(self.radial_bins, dtype=spectrum.imag.dtype, device=spectrum.device) + + # Bin the frequencies - spectrum and radial_coords are both 1D + for i in range(self.radial_bins): + mask = (radial_coords >= bin_edges[i]) & (radial_coords < bin_edges[i + 1]) + if mask.any(): + # spectrum is 1D, so we can index it directly + result_real[i] = spectrum.real[mask].mean() + result_imag[i] = spectrum.imag[mask].mean() + + # Combine real and imaginary parts + result = torch.complex(result_real, result_imag) + + return result + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply 3D Radial Fourier Transform to input data. + + Args: + img: input medical image data. Expected shape: (..., D, H, W) + where D, H, W are spatial dimensions. + + Returns: + Transformed data in radial frequency domain. + """ + # Convert to tensor if needed + img_tensor, *_ = convert_data_type(img, torch.Tensor) + # Get spatial dimensions + spatial_shape = tuple(img_tensor.shape[d] for d in self.spatial_dims) + if len(spatial_shape) != 3: + raise ValueError(f"Expected 3 spatial dimensions, got {len(spatial_shape)}") + + # Compute 3D FFT + # Shift zero frequency to center and compute FFT + spectrum = fftn(ifftshift(img_tensor, dim=self.spatial_dims), dim=self.spatial_dims) + spectrum = fftshift(spectrum, dim=self.spatial_dims) + + # Normalize if requested + if self.normalize: + norm_factor = math.prod(spatial_shape) + spectrum = spectrum / norm_factor + + # Compute radial coordinates + radial_coords = self._compute_radial_coordinates(spatial_shape) + + # Apply radial binning if requested + if self.radial_bins is not None: + # Reshape for radial processing + orig_shape = spectrum.shape + # Move spatial dimensions to end for processing + spatial_indices = [d % len(orig_shape) for d in self.spatial_dims] + non_spatial_indices = [i for i in range(len(orig_shape)) if i not in spatial_indices] + + # Reshape to (non_spatial..., spatial_prod) + flat_shape = (*[orig_shape[i] for i in non_spatial_indices], -1) + spectrum_flat = spectrum.moveaxis(spatial_indices, [-3, -2, -1]).reshape(flat_shape) + radial_flat = radial_coords.flatten() + + # Get non-spatial dimensions (batch, channel, etc.) + non_spatial_dims = spectrum_flat.shape[:-1] + spatial_size = spectrum_flat.shape[-1] + + # Reshape to 2D: (non_spatial_product, spatial_size) + non_spatial_product = 1 + for dim in non_spatial_dims: + non_spatial_product *= dim + + spectrum_2d = spectrum_flat.reshape(non_spatial_product, spatial_size) + + # Process each non-spatial element (batch/channel combination) + results = [] + for i in range(non_spatial_product): + elem_spectrum = spectrum_2d[i] # Get spatial frequencies for this batch/channel + radial_result = self._compute_radial_spectrum(elem_spectrum, radial_flat) + results.append(radial_result) + + # Combine results and reshape back + spectrum = torch.stack(results, dim=0) + spectrum = spectrum.reshape(*non_spatial_dims, self.radial_bins) + else: + # Apply frequency mask if max_frequency < 1.0 + if self.max_frequency < 1.0: + freq_mask = radial_coords <= (self.max_frequency * 0.5) + # Expand mask to match spectrum dimensions + for _ in range(len(self.spatial_dims)): + freq_mask = freq_mask.unsqueeze(0) + spectrum = spectrum * freq_mask + + # Extract magnitude and/or phase as requested + output = None + if self.return_magnitude: + magnitude = torch.abs(spectrum) + output = magnitude if output is None else torch.cat([output, magnitude], dim=-1) + + if self.return_phase: + phase = torch.angle(spectrum) + output = phase if output is None else torch.cat([output, phase], dim=-1) + + # Convert back to original data type + output, *_ = convert_data_type(output, type(img)) + + return output + + def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...]) -> NdarrayOrTensor: + """ + Inverse transform from radial frequency domain to spatial domain. + + Args: + radial_data: data in radial frequency domain. + original_shape: original spatial shape (D, H, W). + + Returns: + Reconstructed spatial data. + + Note: + This is an approximate inverse when radial_bins is used. + """ + if self.radial_bins is None: + # Direct inverse FFT + radial_tensor, *_ = convert_data_type(radial_data, torch.Tensor) + + # Separate magnitude and phase if needed + if self.return_magnitude and self.return_phase: + # Assuming they were concatenated along last dimension + split_idx = radial_tensor.shape[-1] // 2 + magnitude = radial_tensor[..., :split_idx] + phase = radial_tensor[..., split_idx:] + radial_tensor = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase)) + + # Apply inverse FFT + result = ifftn(ifftshift(radial_tensor, dim=self.spatial_dims), dim=self.spatial_dims) + result = fftshift(result, dim=self.spatial_dims) + + if self.normalize: + result = result * math.prod(original_shape) + + result, *_ = convert_data_type(result.real, type(radial_data)) + return result + + else: + raise NotImplementedError( + "Exact inverse transform not available for radially binned data. " + "Consider using radial_bins=None for applications requiring inversion." + ) + + +class RadialFourierFeatures3D(Transform): + """ + Extract radial Fourier features for medical image analysis. + + Computes multiple radial Fourier transforms with different parameters + to create a comprehensive frequency feature representation. + + Args: + n_bins_list: list of radial bin counts to compute. + return_types: list of return types: 'magnitude', 'phase', or 'complex'. + normalize: if True, normalize the output. + + Returns: + Concatenated radial Fourier features. + + Example: + >>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128]) + >>> image = torch.randn(1, 128, 128, 96) + >>> features = transform(image) # Shape: (1, 32+64+128=224) + """ + + def __init__( + self, + n_bins_list: Sequence[int] = (32, 64, 128), + return_types: Sequence[str] = ("magnitude",), + normalize: bool = True, + ) -> None: + super().__init__() + self.n_bins_list = n_bins_list + self.return_types = return_types + self.normalize = normalize + + # Create individual transforms + self.transforms = [] + for n_bins in n_bins_list: + for return_type in return_types: + transform = RadialFourier3D( + normalize=normalize, + return_magnitude=(return_type in ["magnitude", "complex"]), + return_phase=(return_type in ["phase", "complex"]), + radial_bins=n_bins, + ) + self.transforms.append(transform) + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """Extract radial Fourier features.""" + features = [] + for transform in self.transforms: + feat = transform(img) + features.append(feat) + + # Concatenate along last dimension + if features: + # Convert all features to tensors if any are numpy arrays + features_tensors = [] + for feat in features: + if isinstance(feat, np.ndarray): + features_tensors.append(torch.from_numpy(feat)) + else: + features_tensors.append(feat) + output = torch.cat(features_tensors, dim=-1) + else: + output = img + + # Convert to original type if needed + if isinstance(img, np.ndarray): + output = output.cpu().numpy() + + return output diff --git a/tests/test_radial_fourier.py b/tests/test_radial_fourier.py new file mode 100644 index 0000000000..6b2caa0810 --- /dev/null +++ b/tests/test_radial_fourier.py @@ -0,0 +1,196 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the 3D Radial Fourier Transform. +""" + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RadialFourier3D, RadialFourierFeatures3D +from monai.utils import set_determinism + + +class TestRadialFourier3D(unittest.TestCase): + """Test cases for RadialFourier3D transform.""" + + def setUp(self): + """Set up test fixtures.""" + set_determinism(seed=42) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # Create test data + self.test_image_3d = torch.randn(1, 32, 64, 64, device=self.device) # Batch, D, H, W + self.test_image_4d = torch.randn(2, 1, 48, 64, 64, device=self.device) # Batch, Channel, D, H, W + + def tearDown(self): + """Clean up after tests.""" + set_determinism(seed=None) + + @parameterized.expand( + [ + [{"radial_bins": 32, "return_magnitude": True}, (1, 32)], + [{"radial_bins": 64, "return_magnitude": True, "return_phase": True}, (1, 128)], + [{"radial_bins": None, "return_magnitude": True}, (1, 32, 64, 64)], + [{"radial_bins": 16, "return_magnitude": True, "max_frequency": 0.5}, (1, 16)], + ] + ) + def test_output_shape(self, params, expected_shape): + """Test that output shape matches expectations.""" + transform = RadialFourier3D(**params) + result = transform(self.test_image_3d) + self.assertEqual(result.shape, expected_shape) + + def test_complex_input(self): + """Test with complex-valued input.""" + complex_image = torch.complex( + torch.randn(1, 32, 64, 64, device=self.device), + torch.randn(1, 32, 64, 64, device=self.device), + ) + transform = RadialFourier3D(radial_bins=32, return_magnitude=True) + result = transform(complex_image) + self.assertEqual(result.shape, (1, 32)) + + def test_normalization(self): + """Test normalization affects output scale.""" + transform1 = RadialFourier3D(radial_bins=32, normalize=True) + transform2 = RadialFourier3D(radial_bins=32, normalize=False) + + result1 = transform1(self.test_image_3d) + result2 = transform2(self.test_image_3d) + + # Normalized result should be smaller + self.assertLess(torch.abs(result1).mean().item(), torch.abs(result2).mean().item()) + + def test_inverse_transform(self): + """Test approximate inverse transform.""" + # Use full spectrum for invertibility + transform = RadialFourier3D(radial_bins=None, normalize=True) + + # Forward transform + spectrum = transform(self.test_image_3d) + + # Inverse transform + reconstructed = transform.inverse(spectrum, self.test_image_3d.shape[-3:]) + + # Should have same shape + self.assertEqual(reconstructed.shape, self.test_image_3d.shape) + + def test_deterministic(self): + """Test that transform is deterministic.""" + transform = RadialFourier3D(radial_bins=32) + + result1 = transform(self.test_image_3d) + result2 = transform(self.test_image_3d) + + self.assertTrue(torch.allclose(result1, result2, rtol=1e-5)) + + def test_numpy_input(self): + """Test that numpy arrays are accepted.""" + np_image = self.test_image_3d.cpu().numpy() + transform = RadialFourier3D(radial_bins=32) + + result = transform(np_image) + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (1, 32)) + + @parameterized.expand( + [ + [{"max_frequency": -0.1}], # Invalid negative + [{"max_frequency": 1.5}], # Invalid > 1.0 + [{"radial_bins": 0}], # Invalid zero bins + [{"return_magnitude": False, "return_phase": False}], # No output requested + ] + ) + def test_invalid_parameters(self, params): + """Test that invalid parameters raise errors.""" + with self.assertRaises(ValueError): + RadialFourier3D(**params) + + def test_spatial_dims_parameter(self): + """Test custom spatial dimensions.""" + # Test with 4D input but spatial dims in middle + image = torch.randn(2, 32, 64, 64, 3, device=self.device) # Batch, D, H, W, Channels + transform = RadialFourier3D(radial_bins=16, spatial_dims=(1, 2, 3)) + result = transform(image) + self.assertEqual(result.shape, (2, 3, 16)) + + def test_batch_processing(self): + """Test processing batch of images.""" + batch_size = 4 + batch_image = torch.randn(batch_size, 32, 64, 64, device=self.device) + transform = RadialFourier3D(radial_bins=32) + result = transform(batch_image) + self.assertEqual(result.shape, (batch_size, 32)) + + +class TestRadialFourierFeatures3D(unittest.TestCase): + """Test cases for RadialFourierFeatures3D transform.""" + + def setUp(self): + """Set up test fixtures.""" + set_determinism(seed=42) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.test_image = torch.randn(2, 32, 64, 64, device=self.device) + + def tearDown(self): + """Clean up after tests.""" + set_determinism(seed=None) + + def test_feature_extraction(self): + """Test multi-scale feature extraction.""" + transform = RadialFourierFeatures3D(n_bins_list=[16, 32, 64], return_types=["magnitude"]) + + features = transform(self.test_image) + expected_features = 16 + 32 + 64 # Sum of all bins + + self.assertEqual(features.shape, (2, expected_features)) + + def test_multiple_return_types(self): + """Test with multiple return types.""" + transform = RadialFourierFeatures3D(n_bins_list=[16, 32], return_types=["magnitude", "phase"]) + + features = transform(self.test_image) + # Each bin count appears twice (magnitude and phase) + expected_features = (16 + 32) * 2 + + self.assertEqual(features.shape, (2, expected_features)) + + def test_complex_output(self): + """Test complex output type.""" + transform = RadialFourierFeatures3D(n_bins_list=[16], return_types=["complex"]) + + features = transform(self.test_image) + # Complex returns both magnitude and phase concatenated + self.assertEqual(features.shape, (2, 16 * 2)) + + def test_empty_bins_list(self): + """Test with empty bins list.""" + transform = RadialFourierFeatures3D(n_bins_list=[], return_types=["magnitude"]) + features = transform(self.test_image) + # Should return original image when no transforms + self.assertEqual(features.shape, self.test_image.shape) + + def test_numpy_compatibility(self): + """Test with numpy input.""" + np_image = self.test_image.cpu().numpy() + transform = RadialFourierFeatures3D(n_bins_list=[16, 32]) + + features = transform(np_image) + self.assertIsInstance(features, np.ndarray) + self.assertEqual(features.shape, (2, 16 + 32)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/signal/__init__.py b/tests/transforms/signal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From d12876d85368de52f01ed50e23cb5ba08aac6605 Mon Sep 17 00:00:00 2001 From: Hitendrasinh Rathod Date: Wed, 24 Dec 2025 20:33:46 +0000 Subject: [PATCH 2/5] Fix CodeRabbit review issues for radial Fourier transform - Add device parameter to _compute_radial_coordinates to prevent CPU/GPU mismatch - Fix frequency mask expansion for multi-dimensional inputs - Add reconstruction accuracy test assertion (with proper magnitude+phase for inverse) - Add Raises section to docstring - Remove unused import - Address all review comments --- monai/transforms/signal/radial_fourier.py | 18 ++++++++++++------ tests/test_radial_fourier.py | 5 ++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/monai/transforms/signal/radial_fourier.py b/monai/transforms/signal/radial_fourier.py index e58aefe7e5..e85afe009e 100644 --- a/monai/transforms/signal/radial_fourier.py +++ b/monai/transforms/signal/radial_fourier.py @@ -25,10 +25,10 @@ from monai.config import NdarrayOrTensor from monai.transforms.transform import Transform -from monai.utils import convert_data_type, optional_import +from monai.utils import convert_data_type # Optional imports for type checking -spatial, _ = optional_import("monai.utils", name="spatial") +# spatial, _ = optional_import("monai.utils", name="spatial") # Commented out unused import class RadialFourier3D(Transform): @@ -59,6 +59,10 @@ class RadialFourier3D(Transform): >>> transform = RadialFourier3D(radial_bins=64, return_magnitude=True) >>> image = torch.randn(1, 128, 128, 96) # Batch, Height, Width, Depth >>> result = transform(image) # Shape: (1, 64) + + Raises: + ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1, or both + return_magnitude and return_phase are False. """ def __init__( @@ -89,12 +93,13 @@ def __init__( if not return_magnitude and not return_phase: raise ValueError("At least one of return_magnitude or return_phase must be True") - def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor: + def _compute_radial_coordinates(self, shape: tuple[int, ...], device: torch.device = None) -> torch.Tensor: """ Compute radial distance from frequency domain center. Args: shape: spatial dimensions (D, H, W) or (H, W, D) depending on dims order. + device: device to create tensor on. Returns: Tensor of same spatial shape with radial distances. @@ -103,7 +108,7 @@ def _compute_radial_coordinates(self, shape: tuple[int, ...]) -> torch.Tensor: coords = [] for dim_size in shape: # Create frequency range from -0.5 to 0.5 - freq = torch.fft.fftfreq(dim_size) + freq = torch.fft.fftfreq(dim_size, device=device) coords.append(freq) # Create meshgrid and compute radial distance @@ -176,7 +181,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: spectrum = spectrum / norm_factor # Compute radial coordinates - radial_coords = self._compute_radial_coordinates(spatial_shape) + radial_coords = self._compute_radial_coordinates(spatial_shape, device=spectrum.device) # Apply radial binning if requested if self.radial_bins is not None: @@ -217,7 +222,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: if self.max_frequency < 1.0: freq_mask = radial_coords <= (self.max_frequency * 0.5) # Expand mask to match spectrum dimensions - for _ in range(len(self.spatial_dims)): + n_non_spatial = len(spectrum.shape) - len(spatial_shape) + for _ in range(n_non_spatial): freq_mask = freq_mask.unsqueeze(0) spectrum = spectrum * freq_mask diff --git a/tests/test_radial_fourier.py b/tests/test_radial_fourier.py index 6b2caa0810..a3911ea44f 100644 --- a/tests/test_radial_fourier.py +++ b/tests/test_radial_fourier.py @@ -76,7 +76,7 @@ def test_normalization(self): def test_inverse_transform(self): """Test approximate inverse transform.""" # Use full spectrum for invertibility - transform = RadialFourier3D(radial_bins=None, normalize=True) + transform = RadialFourier3D(radial_bins=None, normalize=True, return_magnitude=True, return_phase=True) # Forward transform spectrum = transform(self.test_image_3d) @@ -87,6 +87,9 @@ def test_inverse_transform(self): # Should have same shape self.assertEqual(reconstructed.shape, self.test_image_3d.shape) + # Should approximately reconstruct original + self.assertTrue(torch.allclose(reconstructed, self.test_image_3d, atol=1e-5)) + def test_deterministic(self): """Test that transform is deterministic.""" transform = RadialFourier3D(radial_bins=32) From 66c468923dc5e5e668f2911abc625fa59d215606 Mon Sep 17 00:00:00 2001 From: Hitendrasinh Rathod Date: Wed, 24 Dec 2025 20:45:28 +0000 Subject: [PATCH 3/5] chore(tests): remove unused test fixture test_image_4d --- tests/test_radial_fourier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_radial_fourier.py b/tests/test_radial_fourier.py index a3911ea44f..15c62920de 100644 --- a/tests/test_radial_fourier.py +++ b/tests/test_radial_fourier.py @@ -32,7 +32,7 @@ def setUp(self): # Create test data self.test_image_3d = torch.randn(1, 32, 64, 64, device=self.device) # Batch, D, H, W - self.test_image_4d = torch.randn(2, 1, 48, 64, 64, device=self.device) # Batch, Channel, D, H, W + def tearDown(self): """Clean up after tests.""" From 8fb969e6f90e7f15ae8f9a4a4c0155253ecdcd1d Mon Sep 17 00:00:00 2001 From: Hitendrasinh Rathod Date: Wed, 24 Dec 2025 21:05:33 +0000 Subject: [PATCH 4/5] style: fix import sorting and formatting issues --- monai/transforms/signal/radial_fourier.py | 3 +-- tests/test_radial_fourier.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/transforms/signal/radial_fourier.py b/monai/transforms/signal/radial_fourier.py index e85afe009e..0e685aa457 100644 --- a/monai/transforms/signal/radial_fourier.py +++ b/monai/transforms/signal/radial_fourier.py @@ -15,9 +15,8 @@ from __future__ import annotations import math -from typing import Optional, Union - from collections.abc import Sequence +from typing import Optional, Union import numpy as np import torch diff --git a/tests/test_radial_fourier.py b/tests/test_radial_fourier.py index 15c62920de..f6e4d081c9 100644 --- a/tests/test_radial_fourier.py +++ b/tests/test_radial_fourier.py @@ -12,6 +12,8 @@ Tests for the 3D Radial Fourier Transform. """ +from __future__ import annotations + import unittest import numpy as np @@ -33,8 +35,6 @@ def setUp(self): # Create test data self.test_image_3d = torch.randn(1, 32, 64, 64, device=self.device) # Batch, D, H, W - - def tearDown(self): """Clean up after tests.""" set_determinism(seed=None) From ead08157662e096bf22d3042e87dba975910d096 Mon Sep 17 00:00:00 2001 From: Hitendrasinh Rathod Date: Wed, 24 Dec 2025 21:21:11 +0000 Subject: [PATCH 5/5] fix(tests): correct setUp/tearDown structure in test_radial_fourier --- tests/test_radial_fourier.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_radial_fourier.py b/tests/test_radial_fourier.py index f6e4d081c9..c1ff2933f3 100644 --- a/tests/test_radial_fourier.py +++ b/tests/test_radial_fourier.py @@ -35,6 +35,7 @@ def setUp(self): # Create test data self.test_image_3d = torch.randn(1, 32, 64, 64, device=self.device) # Batch, D, H, W + def tearDown(self): """Clean up after tests.""" set_determinism(seed=None)