From ab681131b1bb5f8100dd5311d185fed1447d6e73 Mon Sep 17 00:00:00 2001 From: isalia20 Date: Tue, 27 Aug 2024 01:32:35 +0400 Subject: [PATCH 1/3] compute area for all types of boxes --- kornia/geometry/boxes.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/kornia/geometry/boxes.py b/kornia/geometry/boxes.py index c08ba867d51..a9bd12ae506 100644 --- a/kornia/geometry/boxes.py +++ b/kornia/geometry/boxes.py @@ -392,9 +392,19 @@ def filter_boxes_by_area( def compute_area(self) -> torch.Tensor: """Returns :math:`(B, N)`.""" - w = self._data[..., 1, 0] - self._data[..., 0, 0] - h = self._data[..., 2, 1] - self._data[..., 0, 1] - return (w * h).unsqueeze(0) if self._data.ndim == 3 else (w * h) + coords = self._data.view((-1, 4, 2)) if self._data.ndim == 4 else self._data + # calculate centroid of the box + centroid = coords.mean(dim=1, keepdim=True) + # calculate the angle from centroid to each corner + angles = torch.atan2(coords[..., 1] - centroid[..., 1], coords[..., 0] - centroid[..., 0]) + # sort the corners by angle to get an order for shoelace formula + _, clockwise_indices = torch.sort(angles, dim=1, descending=True) + # gather the corners in the new order + ordered_corners = torch.gather(coords, 1, clockwise_indices.unsqueeze(-1).expand(-1, -1, 2)) + x, y = ordered_corners[..., 0], ordered_corners[..., 1] + # Gaussian/Shoelace formula https://en.wikipedia.org/wiki/Shoelace_formula + area = 0.5 * torch.abs(torch.sum((x * torch.roll(y, 1, 1)) - (y * torch.roll(x, 1, 1)), dim=1)) + return area.view(self._data.shape[:2]) if self._data.ndim == 4 else area @classmethod def from_tensor( From 5e8f130ac7931bea1d803d7649fc19cf2ae8a977 Mon Sep 17 00:00:00 2001 From: isalia20 Date: Wed, 28 Aug 2024 22:14:37 +0400 Subject: [PATCH 2/3] tests for compute area of box --- tests/geometry/test_boxes.py | 51 ++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/geometry/test_boxes.py b/tests/geometry/test_boxes.py index 3f4f8c706dc..da4888ada47 100644 --- a/tests/geometry/test_boxes.py +++ b/tests/geometry/test_boxes.py @@ -277,7 +277,58 @@ def apply_boxes_method(tensor: torch.Tensor, method: str, **kwargs): self.gradcheck(partial(apply_boxes_method, method="get_boxes_shape"), (t_boxes1,)) self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xyxy_plus").data, (t_boxes_xyxy,)) self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xywh").data, (t_boxes_xyxy1,)) + + def test_compute_area(self): + # Rectangle + box_1 = [ + [0.0, 0.0], + [100.0, 0.0], + [100.0, 50.0], + [0.0, 50.0] + ] + # Trapezoid + box_2 = [ + [0.0, 0.0], + [60.0, 0.0], + [40.0, 50.0], + [20.0, 50.0] + ] + # Parallelogram + box_3 = [ + [0.0, 0.0], + [100.0, 0.0], + [120.0, 50.0], + [20.0, 50.0] + ] + # Random quadrilateral + box_4 = [ + [50.0, 50.0], + [150.0, 250.0], + [0.0, 500.0], + [27.0, 80], + ] + # Random quadrilateral + box_5 = [ + [0.0, 0.0], + [150.0, 0.0], + [150.0, 150.0], + [0.0, 0.5], + ] + # Rectangle with minus coordinates + box_6 = [ + [-500.0, -500.0], + [-300.0, -500.0], + [-300.0, -300.0], + [-500.0, -300.0] + ] + expected_values = [5000.0, 2000.0, 5000.0, 31925.0, 11287.5, 40000.0] + box_coordinates = torch.tensor([box_1, box_2, box_3, box_4, box_5, box_6]) + computed_areas = Boxes(box_coordinates).compute_area().tolist() + computed_areas_w_batch = Boxes(box_coordinates.reshape(2, 3, 4, 2)).compute_area().tolist() + flattened_computed_areas_w_batch = [area for batch in computed_areas_w_batch for area in batch] + assert all([computed_area == expected_area for computed_area, expected_area in zip(computed_areas, expected_values)]) + assert all([computed_area == expected_area for computed_area, expected_area in zip(flattened_computed_areas_w_batch, expected_values)]) class TestTransformBoxes2D(BaseTester): def test_transform_boxes(self, device, dtype): From 13419c866976940fdba593d464a99147866059f0 Mon Sep 17 00:00:00 2001 From: isalia20 Date: Wed, 28 Aug 2024 22:15:48 +0400 Subject: [PATCH 3/3] linter --- tests/geometry/test_boxes.py | 42 ++++++++++++------------------------ 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/tests/geometry/test_boxes.py b/tests/geometry/test_boxes.py index da4888ada47..5d98711027b 100644 --- a/tests/geometry/test_boxes.py +++ b/tests/geometry/test_boxes.py @@ -277,29 +277,14 @@ def apply_boxes_method(tensor: torch.Tensor, method: str, **kwargs): self.gradcheck(partial(apply_boxes_method, method="get_boxes_shape"), (t_boxes1,)) self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xyxy_plus").data, (t_boxes_xyxy,)) self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xywh").data, (t_boxes_xyxy1,)) - + def test_compute_area(self): # Rectangle - box_1 = [ - [0.0, 0.0], - [100.0, 0.0], - [100.0, 50.0], - [0.0, 50.0] - ] + box_1 = [[0.0, 0.0], [100.0, 0.0], [100.0, 50.0], [0.0, 50.0]] # Trapezoid - box_2 = [ - [0.0, 0.0], - [60.0, 0.0], - [40.0, 50.0], - [20.0, 50.0] - ] + box_2 = [[0.0, 0.0], [60.0, 0.0], [40.0, 50.0], [20.0, 50.0]] # Parallelogram - box_3 = [ - [0.0, 0.0], - [100.0, 0.0], - [120.0, 50.0], - [20.0, 50.0] - ] + box_3 = [[0.0, 0.0], [100.0, 0.0], [120.0, 50.0], [20.0, 50.0]] # Random quadrilateral box_4 = [ [50.0, 50.0], @@ -315,20 +300,21 @@ def test_compute_area(self): [0.0, 0.5], ] # Rectangle with minus coordinates - box_6 = [ - [-500.0, -500.0], - [-300.0, -500.0], - [-300.0, -300.0], - [-500.0, -300.0] - ] + box_6 = [[-500.0, -500.0], [-300.0, -500.0], [-300.0, -300.0], [-500.0, -300.0]] expected_values = [5000.0, 2000.0, 5000.0, 31925.0, 11287.5, 40000.0] box_coordinates = torch.tensor([box_1, box_2, box_3, box_4, box_5, box_6]) computed_areas = Boxes(box_coordinates).compute_area().tolist() - computed_areas_w_batch = Boxes(box_coordinates.reshape(2, 3, 4, 2)).compute_area().tolist() + computed_areas_w_batch = Boxes(box_coordinates.reshape(2, 3, 4, 2)).compute_area().tolist() flattened_computed_areas_w_batch = [area for batch in computed_areas_w_batch for area in batch] - assert all([computed_area == expected_area for computed_area, expected_area in zip(computed_areas, expected_values)]) - assert all([computed_area == expected_area for computed_area, expected_area in zip(flattened_computed_areas_w_batch, expected_values)]) + assert all( + computed_area == expected_area for computed_area, expected_area in zip(computed_areas, expected_values) + ) + assert all( + computed_area == expected_area + for computed_area, expected_area in zip(flattened_computed_areas_w_batch, expected_values) + ) + class TestTransformBoxes2D(BaseTester): def test_transform_boxes(self, device, dtype):