diff --git a/cellSAM/model.py b/cellSAM/model.py index 23283be..95bbe9c 100644 --- a/cellSAM/model.py +++ b/cellSAM/model.py @@ -144,9 +144,9 @@ def segment_cellular_image( model, img = model.to(device), img.to(device) preds = model.predict(img, x=None, boxes_per_heatmap=bounding_boxes, device=device, fast=fast) - if preds is None: + if preds[0] is None: warn("No cells detected in the image.") - return np.zeros(img.shape[1:], dtype=np.int32), None, None + return np.zeros(img.shape[-2:], dtype=np.uint8), None, torch.empty((1, 4)) segmentation_predictions, _, x, bounding_boxes = preds