diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py
index 05b0baad..46a6f9ac 100644
--- a/metrics/f1/f1.py
+++ b/metrics/f1/f1.py
@@ -39,6 +39,11 @@
         - 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall.
         - 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification).
     sample_weight (`list` of `float`): Sample weights Defaults to None.
+    zero_division (`int` or `str`): Sets the value to return when there is a zero division. Defaults to 0.
+
+        - 0: Returns 0 when there is a zero division.
+        - 1: Returns 1 when there is a zero division.
+        - "warn": Raises a warning and returns 0 when there is a zero division.
 
 Returns:
     f1 (`float` or `array` of `float`): F1 score or list of f1 scores, depending on the value passed to `average`. Minimum possible value is 0. Maximum possible value is 1. Higher f1 scores are better.
@@ -123,8 +128,8 @@ def _info(self):
             reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
         )
 
-    def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
+    def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None, zero_division=0):
         score = f1_score(
-            references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight
+            references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight, zero_division=zero_division
         )
         return {"f1": score if getattr(score, "size", 1) > 1 else float(score)}
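
A minimal usage sketch of the new parameter (not part of the diff), assuming the patch is applied and the metric is loaded through `evaluate.load("f1")`; the loading call and example values are illustrative assumptions:

```python
# Sketch: exercising zero_division through the patched metric (assumes the
# `evaluate` loader forwards extra keyword arguments to `_compute`).
import evaluate

f1_metric = evaluate.load("f1")

# No positive labels and no positive predictions: precision and recall are
# both 0/0, so the F1 score itself is undefined and zero_division decides
# the returned value instead of sklearn's UndefinedMetricWarning path.
print(f1_metric.compute(predictions=[0, 0, 0], references=[0, 0, 0]))
# {'f1': 0.0}  -> new default zero_division=0

print(f1_metric.compute(predictions=[0, 0, 0], references=[0, 0, 0], zero_division=1))
# {'f1': 1.0}  -> undefined components treated as 1
```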