Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 1 addition & 69 deletions src/lenskit/metrics/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def summarize(
}


class GlobalMetric(Metric):
class GlobalMetric:
"""
Base class for metrics that measure entire runs at a time.

Expand All @@ -172,71 +172,3 @@ def measure_run(self, output: ItemListCollection, test: ItemListCollection, /) -
Individual metric classes need to implement this method.
"""
raise NotImplementedError() # pragma: no cover

def measure_list(self, output: ItemList, test: ItemList, /) -> Any:
    """
    Reject per-list measurement for a global metric.

    Raises:
        NotImplementedError: always.
    """
    message = "Global metrics don't support per-list measurement"
    raise NotImplementedError(message)

def summarize(self, values: list[Any] | pa.Array | pa.ChunkedArray, /) -> float:
    """
    Reject per-list summarization for a global metric.

    Raises:
        NotImplementedError: always.
    """
    message = "Global metrics should implement measure_run instead"
    raise NotImplementedError(message)


class DecomposedMetric(Metric):
    """
    Deprecated base class for decomposed metrics.

    .. deprecated:: 2025.4
        This class is deprecated and its functionality has been moved to
        :class:`Metric`.  It is scheduled for removal in 2026.

    Decomposed metrics measure entire runs by flexibly aggregating per-list
    intermediate measurements, and may optionally derive an individual-list
    metric from each of those measurements.

    Stability:
        Full
    """

    def measure_list(self, output: ItemList, test: ItemList, /) -> Any:
        # Adapter: the legacy per-list hook supplies the intermediate data.
        list_data = self.compute_list_data(output, test)
        return list_data

    def extract_list_metrics(self, data: Any, /) -> float | None:
        # Adapter: forward to the legacy (singular) extraction hook.
        per_list = self.extract_list_metric(data)
        return per_list

    def summarize(
        self, values: list[Any] | pa.Array | pa.ChunkedArray, /
    ) -> dict[str, float]:
        # The legacy aggregation hook receives plain Python lists.
        is_arrow = isinstance(values, (pa.Array, pa.ChunkedArray))
        items = values.to_pylist() if is_arrow else values
        aggregate = self.global_aggregate(items)
        if not isinstance(aggregate, (float, int, np.floating, np.integer)):
            return aggregate
        # Bare numbers are wrapped in a single-entry dictionary.
        return {"value": float(aggregate)}

    @abstractmethod
    def compute_list_data(self, output: ItemList, test: ItemList, /) -> Any:
        """
        Compute the intermediate measurement for a single list.

        New implementations should use `measure_list` in `Metric` instead.
        """
        raise NotImplementedError()  # pragma: no cover

    def extract_list_metric(self, data: Any, /) -> float | None:
        """
        Extract a single-list metric from a per-list measurement, if
        applicable.

        Returns:
            The per-list metric, or ``None`` if this metric does not compute
            per-list metrics.

        New implementations should implement
        :meth:`Metric.extract_list_metrics` instead.
        """
        return None

    @abstractmethod
    def global_aggregate(self, values: list[Any], /) -> float | dict[str, float]:
        """
        Aggregate the per-list measurements into a global value.

        New implementations should implement :meth:`Metric.summarize` instead.
        """
        raise NotImplementedError()  # pragma: no cover
7 changes: 1 addition & 6 deletions src/lenskit/metrics/_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from lenskit.data import ItemList, ItemListCollection

from ._base import DecomposedMetric, GlobalMetric, ListMetric, Metric, MetricFunction
from ._base import GlobalMetric, ListMetric, Metric, MetricFunction

_log = logging.getLogger(__name__)
K1 = TypeVar("K1", bound=tuple)
Expand Down Expand Up @@ -46,11 +46,6 @@ def is_global(self) -> bool:
"Check if this metric is global."
return isinstance(self.metric, GlobalMetric)

@property
def is_decomposed(self) -> bool:
    "Report whether the wrapped metric is a decomposed metric."
    wrapped = self.metric
    return isinstance(wrapped, DecomposedMetric)

def measure_list(self, list: ItemList, test: ItemList) -> Any:
"""Get intermediate measurement data from the metric."""
if isinstance(self.metric, Callable):
Expand Down
61 changes: 3 additions & 58 deletions src/lenskit/metrics/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from lenskit.data.adapt import ITEM_COMPAT_COLUMN, normalize_columns
from lenskit.data.types import AliasedColumn

from ._base import DecomposedMetric, ListMetric, Metric
from ._base import ListMetric, Metric

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,7 +108,7 @@ def align_scores(
return pred_s, rate_s


class RMSE(PredictMetric, ListMetric, DecomposedMetric):
class RMSE(PredictMetric, ListMetric):
"""
Compute RMSE (root mean squared error). This is computed as:

Expand All @@ -131,36 +131,8 @@ def measure_list(self, predictions: ItemList, test: ItemList | None = None, /) -
err *= err
return np.sqrt(np.mean(err))

@override
def compute_list_data(self, output, test):
    """
    Return the pair ``(sum of squared errors, number of errors)`` for one list.

    The scores come from ``self.align_scores(output, test)``.
    """
    predicted, actual = self.align_scores(output, test)
    residual = predicted - actual
    squared = residual * residual
    return np.sum(squared), len(squared)

@override
def extract_list_metric(self, metric):
    """
    Turn a ``(squared-error total, count)`` pair into a per-list RMSE.

    Returns NaN when the count is not positive.
    """
    total, count = metric
    return np.sqrt(total / count) if count > 0 else np.nan

@override
def global_aggregate(self, values):
    """
    Pool per-list ``(squared-error total, count)`` pairs into a global RMSE.

    Returns NaN when the pooled count is not positive.
    """
    sq_sum = sum((total for total, _ in values), 0.0)
    count = sum((n for _, n in values), 0.0)
    return np.sqrt(sq_sum / count) if count > 0 else np.nan


class MAE(PredictMetric, ListMetric, DecomposedMetric):
class MAE(PredictMetric, ListMetric):
"""
Compute MAE (mean absolute error). This is computed as:

Expand All @@ -181,30 +153,3 @@ def measure_list(self, predictions: ItemList, test: ItemList | None = None, /) -
ps, ts = self.align_scores(predictions, test)
err = ps - ts
return np.mean(np.abs(err)).item()

@override
def compute_list_data(self, output, test):
    """
    Return the pair ``(sum of absolute errors, number of errors)`` for one list.

    The scores come from ``self.align_scores(output, test)``.
    """
    predicted, actual = self.align_scores(output, test)
    abs_err = np.abs(predicted - actual)
    return np.sum(abs_err), len(abs_err)

@override
def extract_list_metric(self, metric):
    """
    Turn an ``(absolute-error total, count)`` pair into a per-list MAE.

    Returns NaN when the count is not positive.
    """
    total, count = metric
    return total / count if count > 0 else np.nan

@override
def global_aggregate(self, values):
    """
    Pool per-list ``(absolute-error total, count)`` pairs into a global MAE.

    Args:
        values: iterable of ``(total_abs_error, n)`` pairs, one per list.

    Returns:
        The summed absolute error divided by the summed count, or NaN when
        the summed count is not positive (including empty input).
    """
    tot_err = 0.0
    tot_n = 0.0
    for t, n in values:
        tot_err += t
        tot_n += n

    # Test the accumulated count, not the loop variable: ``n`` is unbound
    # for empty input and otherwise only reflects the last list seen.
    if tot_n > 0:
        return tot_err / tot_n
    else:
        return np.nan
4 changes: 2 additions & 2 deletions src/lenskit/metrics/ranking/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from lenskit.data import ItemList

from .._base import DecomposedMetric, GlobalMetric, ListMetric, Metric
from .._base import ListMetric, Metric

__all__ = ["Metric", "ListMetric", "GlobalMetric", "DecomposedMetric", "RankingMetricBase"]
__all__ = ["Metric", "ListMetric", "RankingMetricBase"]


class RankingMetricBase(Metric):
Expand Down
12 changes: 6 additions & 6 deletions src/lenskit/metrics/ranking/_gini.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from lenskit.logging import get_logger
from lenskit.stats import gini

from ._base import DecomposedMetric, RankingMetricBase
from ._base import RankingMetricBase
from ._weighting import GeometricRankWeight, RankWeight

_log = get_logger(__name__)


class GiniBase(DecomposedMetric, RankingMetricBase):
class GiniBase(RankingMetricBase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does RankingMetricBase inherit from ListMetric? Gini is not a listwise metric, so if it does, we should do one of the following:

  • Make RankingMetricBase no longer inherit from ListMetric, and update all the ranking metrics that are listwise metrics to inherit from both RankingMetricBase and ListMetric.
  • Rename RankingMetricBase to RankingListMetricBase and make GiniBase no longer inherit from RankingMetricBase.
  • A combination of the two: introduce RankingListMetricBase to be both ranking metric and list metric, and make the other metrics inherit from it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RankingMetricBase inherits from Metric.

"""
Base class for Gini diversity / popularity concentration metrics.
"""
Expand Down Expand Up @@ -66,12 +66,12 @@ class ListGini(GiniBase):
"""

@override
def compute_list_data(self, output: ItemList, test):
def measure_list(self, output: ItemList, test):
recs = self.truncate(output)
return recs.ids(format="arrow")

@override
def global_aggregate(self, values: list[pa.Array]):
def summarize(self, values: list[pa.Array] | pa.ChunkedArray, /):
log = _log.bind(metric=self.label, item_count=self.item_count)
log.debug("aggregating for %d lists", len(values))
chunked = pa.chunked_array(values)
Expand Down Expand Up @@ -119,13 +119,13 @@ def __init__(
self.weight = weight

@override
def compute_list_data(self, output: ItemList, test):
def measure_list(self, output: ItemList, test):
recs = self.truncate(output)
weights = self.weight.weight(np.arange(1, len(recs) + 1))
return (recs.ids(format="arrow"), pa.array(weights, type=pa.float32()))

@override
def global_aggregate(self, values: list[tuple[pa.Array, pa.FloatArray]]):
def summarize(self, values: list[tuple[pa.Array, pa.FloatArray]]):
log = _log.bind(metric=self.label, item_count=self.item_count)
log.debug("aggregating for %d lists", len(values))
table = pa.Table.from_batches(
Expand Down
56 changes: 1 addition & 55 deletions tests/eval/test_measurement_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from lenskit.basic import PopScorer
from lenskit.data import ItemList, ItemListCollection
from lenskit.metrics import NDCG, Recall
from lenskit.metrics._base import DecomposedMetric, GlobalMetric, ListMetric, Metric
from lenskit.metrics._base import GlobalMetric, ListMetric, Metric
from lenskit.metrics._collect import MeasurementCollector, MetricWrapper
from lenskit.metrics.basic import ListLength
from lenskit.splitting import split_temporal_fraction
Expand Down Expand Up @@ -270,28 +270,6 @@ def summarize(self, values):
# metricWrapper properties and summarization


def test_metricwrapper_is_decomposed_property():
    """``MetricWrapper.is_decomposed`` is true only for decomposed metrics."""

    class LegacyStub(DecomposedMetric):
        label = "dummy_decomp"

        def compute_list_data(self, recs, test):
            return {"a": 1.0}

        def global_aggregate(self, values):
            return {"mean": 1.0}

        def measure_list(self, recs, test):
            return {"a": 1.0}

        def summarize(self, values):
            return {"mean": 1.0}

    decomposed_wrapper = MetricWrapper(LegacyStub(), "decomp")
    assert decomposed_wrapper.is_decomposed

    plain_wrapper = MetricWrapper(ListLength(), "len")
    assert not plain_wrapper.is_decomposed


def test_measure_metric_with_none_summarize():
"""Test metric that returns None from summarize."""

Expand Down Expand Up @@ -347,25 +325,6 @@ def test_full_workflow_integration_improved(ml_ds):
assert 0 <= value <= 1


# test that global metric raises errors for unsupported operations


def test_global_metric_unsupported():
    """Global metrics reject per-list measurement and summarization."""

    class RunOnlyMetric(GlobalMetric):
        label = "global"

        def measure_run(self, run, test):
            return 1.0

    global_metric = RunOnlyMetric()

    per_list_msg = "Global metrics don't support per-list measurement"
    with raises(NotImplementedError, match=per_list_msg):
        global_metric.measure_list(ItemList([1, 2]), ItemList([1]))

    summarize_msg = "Global metrics should implement measure_run instead"
    with raises(NotImplementedError, match=summarize_msg):
        global_metric.summarize([1, 2, 3])


# test edge cases in Metric.summarize


Expand All @@ -390,19 +349,6 @@ def measure_list(self, output, test):
assert result["std"] == 1.0


def test_decomposed_metric_numeric_return():
    """A bare number from ``global_aggregate`` is wrapped as ``{"value": ...}``."""

    class NumericAggregateMetric(DecomposedMetric):
        def compute_list_data(self, output, test):
            return len(output)

        def global_aggregate(self, values):
            return 5.0

    summary = NumericAggregateMetric().summarize([1, 2, 3])
    assert summary == {"value": 5.0}


def test_empty_intermediate_values():
class TestMetric(Metric):
label = "test"
Expand Down
Loading