diff --git a/docs/source/augmentation.module.rst b/docs/source/augmentation.module.rst index 99b0d7e233f..e87fef7c253 100644 --- a/docs/source/augmentation.module.rst +++ b/docs/source/augmentation.module.rst @@ -61,6 +61,7 @@ Geometric .. autoclass:: RandomHorizontalFlip .. autoclass:: RandomPerspective .. autoclass:: RandomResizedCrop +.. autoclass:: RandomRotation90 .. autoclass:: RandomRotation .. autoclass:: RandomShear .. autoclass:: RandomThinPlateSpline diff --git a/kornia/augmentation/_2d/geometric/__init__.py b/kornia/augmentation/_2d/geometric/__init__.py index bfea3dac0a9..d51a7524eee 100644 --- a/kornia/augmentation/_2d/geometric/__init__.py +++ b/kornia/augmentation/_2d/geometric/__init__.py @@ -8,7 +8,7 @@ from kornia.augmentation._2d.geometric.perspective import RandomPerspective from kornia.augmentation._2d.geometric.resize import LongestMaxSize, Resize, SmallestMaxSize from kornia.augmentation._2d.geometric.resized_crop import RandomResizedCrop -from kornia.augmentation._2d.geometric.rotation import RandomRotation +from kornia.augmentation._2d.geometric.rotation import RandomRotation, RandomRotation90 from kornia.augmentation._2d.geometric.shear import RandomShear from kornia.augmentation._2d.geometric.thin_plate_spline import RandomThinPlateSpline from kornia.augmentation._2d.geometric.translate import RandomTranslate diff --git a/kornia/augmentation/_2d/geometric/rotation.py b/kornia/augmentation/_2d/geometric/rotation.py index f2372b84873..c109884bbc9 100644 --- a/kornia/augmentation/_2d/geometric/rotation.py +++ b/kornia/augmentation/_2d/geometric/rotation.py @@ -110,3 +110,104 @@ def inverse_transform( transform=as_tensor(transform, device=input.device, dtype=input.dtype), flags=flags, ) + + +class RandomRotation90(GeometricAugmentationBase2D): + r"""Apply a random 90 * n degree rotation to a tensor image or a batch of tensor images. + + Args: + times: the range of n, the number of 90 degree rotations to be applied. 
+ resample: the interpolation mode. + same_on_batch: apply the same transformation across the batch. + align_corners: interpolation flag. + p: probability of applying the transformation. + keepdim: whether to keep the output shape the same as input (True) or broadcast it + to the batch form (False). + + Shape: + - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)` + - Output: :math:`(B, C, H, W)` + + .. note:: + This function internally uses :func:`kornia.geometry.transform.affine`. This version is relatively + slow as it operates based on affine transformations. + + Examples: + >>> rng = torch.manual_seed(1) + >>> torch.set_printoptions(sci_mode=False) + >>> input = torch.tensor([[1., 0., 0., 2.], + ... [0., 0., 0., 0.], + ... [0., 1., 2., 0.], + ... [0., 0., 1., 2.]]) + >>> aug = RandomRotation90(times=(1, 1), p=1.) + >>> out = aug(input) + >>> out + tensor([[[[ 2.0000, 0.0000, 0.0000, 2.0000], + [ 0.0000, 0.0000, 2.0000, 1.0000], + [ 0.0000, 0.0000, 1.0000, 0.0000], + [ 1.0000, 0.0000, 0.0000, 0.0000]]]]) + >>> aug.transform_matrix + tensor([[[ -0.0000, 1.0000, 0.0000], + [ -1.0000, -0.0000, 3.0000], + [ 0.0000, 0.0000, 1.0000]]]) + >>> inv = aug.inverse(out) + >>> torch.set_printoptions(profile='default') + + To apply the exact augmentation again, you may take advantage of the previous parameter state: + >>> input = torch.randn(1, 3, 32, 32) + >>> aug = RandomRotation90(times=(-1, 1), p=1.) 
+ >>> (aug(input) == aug(input, params=aug._params)).all() + tensor(True) + """ + + def __init__( + self, + times: Tuple[int, int], + resample: Union[str, int, Resample] = Resample.BILINEAR.name, + same_on_batch: bool = False, + align_corners: bool = True, + p: float = 0.5, + keepdim: bool = False, + ) -> None: + super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim) + self._param_generator = rg.PlainUniformGenerator((times, "times", 0.0, (-3, 3))) + + self.flags = {"resample": Resample.get(resample), "align_corners": align_corners} + + def compute_transformation(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor: + # TODO: Update to use `get_rotation_matrix2d` + angles: Tensor = 90.0 * params["times"].round().to(input) + + center: Tensor = _compute_tensor_center(input) + rotation_mat: Tensor = _compute_rotation_matrix(angles, center.expand(angles.shape[0], -1)) + + # rotation_mat is B x 2 x 3 and we need a B x 3 x 3 matrix + trans_mat: Tensor = eye_like(3, input, shared_memory=False) + trans_mat[:, 0] = rotation_mat[:, 0] + trans_mat[:, 1] = rotation_mat[:, 1] + + return trans_mat + + def apply_transform( + self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None + ) -> Tensor: + if not isinstance(transform, Tensor): + raise TypeError(f"Expected the `transform` be a Tensor. Got {type(transform)}.") + + return affine(input, transform[..., :2, :3], flags["resample"].name.lower(), "zeros", flags["align_corners"]) + + def inverse_transform( + self, + input: Tensor, + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + size: Optional[Tuple[int, int]] = None, + ) -> Tensor: + if not isinstance(transform, Tensor): + raise TypeError(f"Expected the `transform` be a Tensor. 
Got {type(transform)}.") + return self.apply_transform( + input, + params=self._params, + transform=as_tensor(transform, device=input.device, dtype=input.dtype), + flags=flags, + ) diff --git a/kornia/augmentation/__init__.py b/kornia/augmentation/__init__.py index f0338975b94..466dd6ffe4e 100644 --- a/kornia/augmentation/__init__.py +++ b/kornia/augmentation/__init__.py @@ -49,6 +49,7 @@ RandomResizedCrop, RandomRGBShift, RandomRotation, + RandomRotation90, RandomSaltAndPepperNoise, RandomSaturation, RandomSharpness, @@ -135,6 +136,7 @@ "RandomPlasmaBrightness", "RandomPlasmaContrast", "RandomResizedCrop", + "RandomRotation90", "RandomRotation", "RandomRGBShift", "RandomSaltAndPepperNoise", diff --git a/tests/augmentation/test_augmentation.py b/tests/augmentation/test_augmentation.py index 939cdc6e3bb..9873d9bf1ef 100644 --- a/tests/augmentation/test_augmentation.py +++ b/tests/augmentation/test_augmentation.py @@ -49,6 +49,7 @@ RandomResizedCrop, RandomRGBShift, RandomRotation, + RandomRotation90, RandomSaltAndPepperNoise, RandomSaturation, RandomSnow, @@ -697,6 +698,83 @@ def test_exception(self): self._create_augmentation_from_params(degrees=(360.0, -360.0)) +class TestRandomRotation90(CommonTests): + possible_params: Dict["str", Tuple] = { + "times": ((-3, 3), (1, 1)), + "resample": (0, Resample.BILINEAR.name, Resample.BILINEAR), + "align_corners": (False, True), + } + _augmentation_cls = RandomRotation90 + _default_param_set: Dict["str", Any] = { + "times": (-3, 3), + "align_corners": True, + } + + @pytest.fixture(params=[_default_param_set], scope="class") + def param_set(self, request): + return request.param + + @pytest.mark.parametrize( + "input_shape,expected_output_shape", + [((3, 4, 5), (1, 3, 4, 5)), ((2, 3, 4, 5), (2, 3, 4, 5))], + ) + def test_cardinality(self, input_shape, expected_output_shape): + self._test_cardinality_implementation( + input_shape=input_shape, + expected_output_shape=expected_output_shape, + params=self._default_param_set, + ) + 
+ def test_random_p_1(self): + torch.manual_seed(42) + + input_tensor = torch.tensor( + [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], + device=self.device, + dtype=self.dtype, + ) + expected_output = torch.tensor( + [[[[0.3, 0.6, 0.9], [0.2, 0.5, 0.8], [0.1, 0.4, 0.7]]]], + device=self.device, + dtype=self.dtype, + ) + + parameters = {"times": (1, 1), "align_corners": True} + self._test_random_p_1_implementation( + input_tensor=input_tensor, + expected_output=expected_output, + params=parameters, + ) + + def test_batch(self): + if self.dtype == torch.float16: + pytest.skip("not work for half-precision") + + torch.manual_seed(12) + + input_tensor = torch.tensor( + [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]], + device=self.device, + dtype=self.dtype, + ).repeat((2, 1, 1, 1)) + expected_output = input_tensor + expected_transformation = kornia.eye_like(3, input_tensor) + parameters = {"times": (0, 0), "align_corners": True} + self._test_random_p_1_implementation( + input_tensor=input_tensor, + expected_output=expected_output, + expected_transformation=expected_transformation, + params=parameters, + ) + + def test_exception(self): + # Wrong type + with pytest.raises(TypeError): + self._create_augmentation_from_params(times="") + with pytest.raises(ValueError): + self._create_augmentation_from_params(times=(30, 60), align_corners=0) + + class TestRandomGrayscaleAlternative(CommonTests): possible_params: Dict["str", Tuple] = {}