diff --git a/metrics/mean_iou/mean_iou.py b/metrics/mean_iou/mean_iou.py index 4c19864d..0911ad76 100644 --- a/metrics/mean_iou/mean_iou.py +++ b/metrics/mean_iou/mean_iou.py @@ -291,23 +291,19 @@ def _info(self): ], ) - def _compute( - self, - predictions, - references, - num_labels: int, - ignore_index: bool, - nan_to_num: Optional[int] = None, - label_map: Optional[Dict[int, int]] = None, - reduce_labels: bool = False, - ): - iou_result = mean_iou( - results=predictions, - gt_seg_maps=references, - num_labels=num_labels, - ignore_index=ignore_index, - nan_to_num=nan_to_num, - label_map=label_map, - reduce_labels=reduce_labels, - ) - return iou_result + def _compute(self, predictions, references, num_labels): + predictions = np.array(predictions) + references = np.array(references) + iou_list = [] + + for label in range(num_labels): + tp = np.sum((predictions == label) & (references == label)) + fp = np.sum((predictions == label) & (references != label)) + fn = np.sum((predictions != label) & (references == label)) + + denom = tp + fp + fn + 1e-10 # Prevent division by zero + iou = tp / denom if denom != 0 else 0.0 + iou_list.append(iou) + + mean_iou = np.mean(iou_list) + return {"mean_iou": mean_iou}