From 2d478b1b0dfb559ba4427ed4a688adedb2fc8b25 Mon Sep 17 00:00:00 2001 From: danikvh Date: Mon, 4 Aug 2025 18:46:53 +0200 Subject: [PATCH] Fix device mismatch in SAM preprocessing --- cellSAM/sam_inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cellSAM/sam_inference.py b/cellSAM/sam_inference.py index 4146cbe..de7643b 100755 --- a/cellSAM/sam_inference.py +++ b/cellSAM/sam_inference.py @@ -120,9 +120,9 @@ def predict_transforms(self, imgs): return imgs - def sam_preprocess(self, x: torch.Tensor, return_paddings=False): + def sam_preprocess(self, x: torch.Tensor, return_paddings=False, device=None): """Normalize pixel values and pad to a square input.""" - x = (x - self.model.pixel_mean) / self.model.pixel_std + x = (x - self.model.pixel_mean.to(device)) / self.model.pixel_std.to(device) h, w = x.shape[-2:] padh = self.model.image_encoder.img_size - h @@ -140,7 +140,7 @@ def sam_bbox_preprocessing(self, x, device=None): x = [ torch.from_numpy(img).permute(2, 0, 1).contiguous().to(device) for img in x ] - x = [self.sam_preprocess(img, return_paddings=True) for img in x] + x = [self.sam_preprocess(img, return_paddings=True, device=device) for img in x] x, paddings = zip(*x) preprocessed_img = torch.stack(x, dim=0) return preprocessed_img, paddings