Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/augmentation.module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Geometric
.. autoclass:: RandomHorizontalFlip
.. autoclass:: RandomPerspective
.. autoclass:: RandomResizedCrop
.. autoclass:: RandomRotation90
.. autoclass:: RandomRotation
.. autoclass:: RandomShear
.. autoclass:: RandomThinPlateSpline
Expand Down
2 changes: 1 addition & 1 deletion kornia/augmentation/_2d/geometric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from kornia.augmentation._2d.geometric.perspective import RandomPerspective
from kornia.augmentation._2d.geometric.resize import LongestMaxSize, Resize, SmallestMaxSize
from kornia.augmentation._2d.geometric.resized_crop import RandomResizedCrop
from kornia.augmentation._2d.geometric.rotation import RandomRotation
from kornia.augmentation._2d.geometric.rotation import RandomRotation, RandomRotation90
from kornia.augmentation._2d.geometric.shear import RandomShear
from kornia.augmentation._2d.geometric.thin_plate_spline import RandomThinPlateSpline
from kornia.augmentation._2d.geometric.translate import RandomTranslate
Expand Down
101 changes: 101 additions & 0 deletions kornia/augmentation/_2d/geometric/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,104 @@ def inverse_transform(
transform=as_tensor(transform, device=input.device, dtype=input.dtype),
flags=flags,
)


class RandomRotation90(GeometricAugmentationBase2D):
r"""Apply a random 90 * n degree rotation to a tensor image or a batch of tensor images.

Args:
times: the range of n times 90 degree rotation needs to be applied.
resample: Default: the interpolation mode.
same_on_batch: apply the same transformation across the batch.
align_corners: interpolation flag.
p: probability of applying the transformation.
keepdim: whether to keep the output shape the same as input (True) or broadcast it
to the batch form (False).

Shape:
- Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)`
- Output: :math:`(B, C, H, W)`

.. note::
This function internally uses :func:`kornia.geometry.transform.affine`. This version is relatively
slow as it operates based on affine transformations.

Examples:
>>> rng = torch.manual_seed(1)
>>> torch.set_printoptions(sci_mode=False)
>>> input = torch.tensor([[1., 0., 0., 2.],
... [0., 0., 0., 0.],
... [0., 1., 2., 0.],
... [0., 0., 1., 2.]])
>>> aug = RandomRotation90(times=(1, 1), p=1.)
>>> out = aug(input)
>>> out
tensor([[[[ 2.0000, 0.0000, 0.0000, 2.0000],
[ 0.0000, 0.0000, 2.0000, 1.0000],
[ 0.0000, 0.0000, 1.0000, 0.0000],
[ 1.0000, 0.0000, 0.0000, 0.0000]]]])
>>> aug.transform_matrix
tensor([[[ -0.0000, 1.0000, 0.0000],
[ -1.0000, -0.0000, 3.0000],
[ 0.0000, 0.0000, 1.0000]]])
>>> inv = aug.inverse(out)
>>> torch.set_printoptions(profile='default')

To apply the exact augmenation again, you may take the advantage of the previous parameter state:
>>> input = torch.randn(1, 3, 32, 32)
>>> aug = RandomRotation90(times=(-1, 1), p=1.)
>>> (aug(input) == aug(input, params=aug._params)).all()
tensor(True)
"""

def __init__(
self,
times: Tuple[int, int],
resample: Union[str, int, Resample] = Resample.BILINEAR.name,
same_on_batch: bool = False,
align_corners: bool = True,
p: float = 0.5,
keepdim: bool = False,
) -> None:
super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
self._param_generator = rg.PlainUniformGenerator((times, "times", 0.0, (-3, 3)))

self.flags = {"resample": Resample.get(resample), "align_corners": align_corners}

def compute_transformation(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
# TODO: Update to use `get_rotation_matrix2d`
angles: Tensor = 90.0 * params["times"].round().to(input)

center: Tensor = _compute_tensor_center(input)
rotation_mat: Tensor = _compute_rotation_matrix(angles, center.expand(angles.shape[0], -1))

# rotation_mat is B x 2 x 3 and we need a B x 3 x 3 matrix
trans_mat: Tensor = eye_like(3, input, shared_memory=False)
trans_mat[:, 0] = rotation_mat[:, 0]
trans_mat[:, 1] = rotation_mat[:, 1]

return trans_mat

def apply_transform(
self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
) -> Tensor:
if not isinstance(transform, Tensor):
raise TypeError(f"Expected the `transform` be a Tensor. Got {type(transform)}.")

return affine(input, transform[..., :2, :3], flags["resample"].name.lower(), "zeros", flags["align_corners"])

def inverse_transform(
self,
input: Tensor,
flags: Dict[str, Any],
transform: Optional[Tensor] = None,
size: Optional[Tuple[int, int]] = None,
) -> Tensor:
if not isinstance(transform, Tensor):
raise TypeError(f"Expected the `transform` be a Tensor. Got {type(transform)}.")
return self.apply_transform(
input,
params=self._params,
transform=as_tensor(transform, device=input.device, dtype=input.dtype),
flags=flags,
)
2 changes: 2 additions & 0 deletions kornia/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
RandomResizedCrop,
RandomRGBShift,
RandomRotation,
RandomRotation90,
RandomSaltAndPepperNoise,
RandomSaturation,
RandomSharpness,
Expand Down Expand Up @@ -135,6 +136,7 @@
"RandomPlasmaBrightness",
"RandomPlasmaContrast",
"RandomResizedCrop",
"RandomRotation90",
"RandomRotation",
"RandomRGBShift",
"RandomSaltAndPepperNoise",
Expand Down
78 changes: 78 additions & 0 deletions tests/augmentation/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
RandomResizedCrop,
RandomRGBShift,
RandomRotation,
RandomRotation90,
RandomSaltAndPepperNoise,
RandomSaturation,
RandomSnow,
Expand Down Expand Up @@ -697,6 +698,83 @@ def test_exception(self):
self._create_augmentation_from_params(degrees=(360.0, -360.0))


class TestRandomRotation90(CommonTests):
possible_params: Dict["str", Tuple] = {
"times": ((-3, 3), (1, 1)),
"resample": (0, Resample.BILINEAR.name, Resample.BILINEAR),
"align_corners": (False, True),
}
_augmentation_cls = RandomRotation90
_default_param_set: Dict["str", Any] = {
"times": (-3, 3),
"align_corners": True,
}

@pytest.fixture(params=[_default_param_set], scope="class")
def param_set(self, request):
return request.param

@pytest.mark.parametrize(
"input_shape,expected_output_shape",
[((3, 4, 5), (1, 3, 4, 5)), ((2, 3, 4, 5), (2, 3, 4, 5))],
)
def test_cardinality(self, input_shape, expected_output_shape):
self._test_cardinality_implementation(
input_shape=input_shape,
expected_output_shape=expected_output_shape,
params=self._default_param_set,
)

def test_random_p_1(self):
torch.manual_seed(42)

input_tensor = torch.tensor(
[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]],
device=self.device,
dtype=self.dtype,
)
expected_output = torch.tensor(
[[[[0.3, 0.6, 0.9], [0.2, 0.5, 0.8], [0.1, 0.4, 0.7]]]],
device=self.device,
dtype=self.dtype,
)

parameters = {"times": (1, 1), "align_corners": True}
self._test_random_p_1_implementation(
input_tensor=input_tensor,
expected_output=expected_output,
params=parameters,
)

def test_batch(self):
if self.dtype == torch.float16:
pytest.skip("not work for half-precision")

torch.manual_seed(12)

input_tensor = torch.tensor(
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]],
device=self.device,
dtype=self.dtype,
).repeat((2, 1, 1, 1))
expected_output = input_tensor
expected_transformation = kornia.eye_like(3, input_tensor)
parameters = {"times": (0, 0), "align_corners": True}
self._test_random_p_1_implementation(
input_tensor=input_tensor,
expected_output=expected_output,
expected_transformation=expected_transformation,
params=parameters,
)

def test_exception(self):
# Wrong type
with pytest.raises(TypeError):
self._create_augmentation_from_params(times="")
with pytest.raises(ValueError):
self._create_augmentation_from_params(times=(30, 60), align_corners=0)


class TestRandomGrayscaleAlternative(CommonTests):
possible_params: Dict["str", Tuple] = {}

Expand Down