diff --git a/cellSAM/AnchorDETR/util/box_ops.py b/cellSAM/AnchorDETR/util/box_ops.py index 9ffe9a7..c62f143 100644 --- a/cellSAM/AnchorDETR/util/box_ops.py +++ b/cellSAM/AnchorDETR/util/box_ops.py @@ -81,8 +81,8 @@ def masks_to_boxes(masks): h, w = masks.shape[-2:] - y = torch.arange(0, h, dtype=torch.float) - x = torch.arange(0, w, dtype=torch.float) + y = torch.arange(0, h, dtype=torch.float, device=masks.device) + x = torch.arange(0, w, dtype=torch.float, device=masks.device) y, x = torch.meshgrid(y, x) x_mask = (masks * x.unsqueeze(0)) @@ -93,4 +93,4 @@ def masks_to_boxes(masks): y_max = y_mask.flatten(1).max(-1)[0] y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] - return torch.stack([x_min, y_min, x_max, y_max], 1) + return torch.stack([x_min, y_min, x_max, y_max], 1) \ No newline at end of file