diff --git a/cellSAM/model.py b/cellSAM/model.py index 5f314e7..008ddf4 100644 --- a/cellSAM/model.py +++ b/cellSAM/model.py @@ -97,6 +97,7 @@ def segment_cellular_image( mask (np.array): Integer array with shape (H, W) x (np.array | None): Image embedding bounding_boxes (np.array | None): list of bounding boxes + scores (np.array): Confidence scores associated with each segmented cell """ if "cuda" in device: assert ( @@ -125,16 +126,18 @@ def segment_cellular_image( warn("No cells detected in the image.") return np.zeros(img.shape[1:], dtype=np.int32), None, None - segmentation_predictions, _, x, bounding_boxes = preds + segmentation_predictions, _, x, bounding_boxes, scores = preds if postprocess: segmentation_predictions = postprocess_predictions(segmentation_predictions) - mask = fill_holes_and_remove_small_masks(segmentation_predictions, min_size=25) + mask, removed_indices = fill_holes_and_remove_small_masks(segmentation_predictions, min_size=25) + scores = np.delete(scores, removed_indices) # Remove corresponding scores + if remove_boundaries: mask = subtract_boundaries(mask) - return mask, x.cpu().numpy(), bounding_boxes + return mask, x.cpu().numpy(), bounding_boxes, scores def postprocess_predictions(mask: np.ndarray): diff --git a/cellSAM/napari_plugin/_widget.py b/cellSAM/napari_plugin/_widget.py index af32fb4..90ea764 100644 --- a/cellSAM/napari_plugin/_widget.py +++ b/cellSAM/napari_plugin/_widget.py @@ -213,7 +213,7 @@ def _on_segment_all(self): inp = torch.from_numpy(inp).unsqueeze(0) - preds, _, x, _ = self._cellsam_model.predict( + preds, _, x, _, _ = self._cellsam_model.predict( inp.to(self._device), x=self._embedding, boxes_per_heatmap=None, @@ -224,7 +224,7 @@ def _on_segment_all(self): warn("No cells detected!") return - mask = fill_holes_and_remove_small_masks(preds, min_size=25) + mask, _ = fill_holes_and_remove_small_masks(preds, min_size=25) # Update the segmentation layer self._segmentation_layer.data = mask @@ -338,7 +338,7 @@ def _on_interactive_run(self, _: Optional[Any] = None) -> None: inp = torch.from_numpy(inp).unsqueeze(0) - preds, _, x, _ = self._cellsam_model.predict( + preds, _, x, _, _ = self._cellsam_model.predict( inp.to(self._device), x=self._embedding, boxes_per_heatmap=torch.tensor(formatted_boxes) @@ -349,7 +349,7 @@ def _on_interactive_run(self, _: Optional[Any] = None) -> None: if preds is None: warn("No cells detected!") else: - preds = fill_holes_and_remove_small_masks(preds, min_size=25) + preds, _ = fill_holes_and_remove_small_masks(preds, min_size=25) self._mask_layer.data = preds self._confirm_mask_btn.enabled = True self._cancel_annot_btn.enabled = True diff --git a/cellSAM/sam_inference.py b/cellSAM/sam_inference.py index 4146cbe..02d8b9d 100755 --- a/cellSAM/sam_inference.py +++ b/cellSAM/sam_inference.py @@ -330,4 +330,4 @@ def predict( # sum all masks, #TODO: double check if max is the right move here thresholded_masks_summed = np.max(thresholded_masks_summed, axis=0) - return thresholded_masks_summed, thresholded_masks, x, boxes_per_heatmap + return thresholded_masks_summed, thresholded_masks, x, boxes_per_heatmap, scores diff --git a/cellSAM/utils.py b/cellSAM/utils.py index af2e68f..f517f60 100644 --- a/cellSAM/utils.py +++ b/cellSAM/utils.py @@ -293,6 +293,9 @@ def fill_holes_and_remove_small_masks(masks, min_size=15): masks with holes filled and masks smaller than min_size removed, 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx] + removed_indices : list of int + List of original mask indices that were removed because they + were smaller than `min_size`. """ @@ -301,6 +304,7 @@ def fill_holes_and_remove_small_masks(masks, min_size=15): "masks_to_outlines takes 2D or 3D array, not %dD array" % masks.ndim ) + removed_indices = [] slices = find_objects(masks) j = 0 for i, slc in enumerate(slices): @@ -309,6 +313,7 @@ def fill_holes_and_remove_small_masks(masks, min_size=15): npix = msk.sum() if min_size > 0 and npix < min_size: masks[slc][msk] = 0 + removed_indices.append(i) elif npix > 0: if msk.ndim == 3: for k in range(msk.shape[0]): @@ -317,4 +322,4 @@ def fill_holes_and_remove_small_masks(masks, min_size=15): msk = binary_fill_holes(msk) masks[slc][msk] = j + 1 j += 1 - return masks + return masks, removed_indices