diff --git a/kornia/geometry/boxes.py b/kornia/geometry/boxes.py index c08ba867d51..a9bd12ae506 100644 --- a/kornia/geometry/boxes.py +++ b/kornia/geometry/boxes.py @@ -392,9 +392,19 @@ def filter_boxes_by_area( def compute_area(self) -> torch.Tensor: """Returns :math:`(B, N)`.""" - w = self._data[..., 1, 0] - self._data[..., 0, 0] - h = self._data[..., 2, 1] - self._data[..., 0, 1] - return (w * h).unsqueeze(0) if self._data.ndim == 3 else (w * h) + coords = self._data.view((-1, 4, 2)) if self._data.ndim == 4 else self._data + # calculate centroid of the box + centroid = coords.mean(dim=1, keepdim=True) + # calculate the angle from centroid to each corner + angles = torch.atan2(coords[..., 1] - centroid[..., 1], coords[..., 0] - centroid[..., 0]) + # sort the corners by angle to get an order for shoelace formula + _, clockwise_indices = torch.sort(angles, dim=1, descending=True) + # gather the corners in the new order + ordered_corners = torch.gather(coords, 1, clockwise_indices.unsqueeze(-1).expand(-1, -1, 2)) + x, y = ordered_corners[..., 0], ordered_corners[..., 1] + # Gaussian/Shoelace formula https://en.wikipedia.org/wiki/Shoelace_formula + area = 0.5 * torch.abs(torch.sum((x * torch.roll(y, 1, 1)) - (y * torch.roll(x, 1, 1)), dim=1)) + return area.view(self._data.shape[:2]) if self._data.ndim == 4 else area @classmethod def from_tensor( diff --git a/tests/geometry/test_boxes.py b/tests/geometry/test_boxes.py index 3f4f8c706dc..5d98711027b 100644 --- a/tests/geometry/test_boxes.py +++ b/tests/geometry/test_boxes.py @@ -278,6 +278,43 @@ def apply_boxes_method(tensor: torch.Tensor, method: str, **kwargs): self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xyxy_plus").data, (t_boxes_xyxy,)) self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xywh").data, (t_boxes_xyxy1,)) + def test_compute_area(self): + # Rectangle + box_1 = [[0.0, 0.0], [100.0, 0.0], [100.0, 50.0], [0.0, 50.0]] + # Trapezoid + box_2 = [[0.0, 0.0], [60.0, 0.0], [40.0, 50.0], [20.0, 50.0]] + # Parallelogram + box_3 = [[0.0, 0.0], [100.0, 0.0], [120.0, 50.0], [20.0, 50.0]] + # Random quadrilateral + box_4 = [ + [50.0, 50.0], + [150.0, 250.0], + [0.0, 500.0], + [27.0, 80], + ] + # Random quadrilateral + box_5 = [ + [0.0, 0.0], + [150.0, 0.0], + [150.0, 150.0], + [0.0, 0.5], + ] + # Rectangle with minus coordinates + box_6 = [[-500.0, -500.0], [-300.0, -500.0], [-300.0, -300.0], [-500.0, -300.0]] + + expected_values = [5000.0, 2000.0, 5000.0, 31925.0, 11287.5, 40000.0] + box_coordinates = torch.tensor([box_1, box_2, box_3, box_4, box_5, box_6]) + computed_areas = Boxes(box_coordinates).compute_area().tolist() + computed_areas_w_batch = Boxes(box_coordinates.reshape(2, 3, 4, 2)).compute_area().tolist() + flattened_computed_areas_w_batch = [area for batch in computed_areas_w_batch for area in batch] + assert all( + computed_area == expected_area for computed_area, expected_area in zip(computed_areas, expected_values) + ) + assert all( + computed_area == expected_area + for computed_area, expected_area in zip(flattened_computed_areas_w_batch, expected_values) + ) + class TestTransformBoxes2D(BaseTester): def test_transform_boxes(self, device, dtype):