9 changes: 7 additions & 2 deletions metrics/f1/f1.py
@@ -39,6 +39,11 @@
         - 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall.
         - 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification).
     sample_weight (`list` of `float`): Sample weights. Defaults to None.
+    zero_division (`int` or `str`): Sets the value to return when there is a zero division. Defaults to 0.
+
+        - 0: Returns 0 when there is a zero division.
+        - 1: Returns 1 when there is a zero division.
+        - "warn": Raises a warning and returns 0 when there is a zero division.
 
 Returns:
     f1 (`float` or `array` of `float`): F1 score or list of f1 scores, depending on the value passed to `average`. Minimum possible value is 0. Maximum possible value is 1. Higher f1 scores are better.
@@ -123,8 +128,8 @@ def _info(self):
             reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
         )
 
-    def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
+    def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None, zero_division=0):
         score = f1_score(
-            references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight
+            references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight, zero_division=zero_division
         )
         return {"f1": score if getattr(score, "size", 1) > 1 else float(score)}