diff --git a/src/lenskit/sklearn/nmf.py b/src/lenskit/sklearn/nmf.py index 7733cef8b..e6ff897d6 100644 --- a/src/lenskit/sklearn/nmf.py +++ b/src/lenskit/sklearn/nmf.py @@ -15,7 +15,7 @@ import numpy as np from pydantic import AliasChoices, BaseModel, Field -from sklearn.decomposition import non_negative_factorization +from sklearn.decomposition import MiniBatchNMF, non_negative_factorization from typing_extensions import Literal, override from lenskit.data import Dataset, ItemList, QueryInput, RecQuery @@ -44,6 +44,7 @@ class NMFConfig(BaseModel, extra="forbid"): alpha_W: float = 0.0 alpha_H: float | Literal["same"] = "same" l1_ratio: float = 0.0 + method: Literal["full", "minibatch"] = "full" class NMFScorer(Component[ItemList], Trainable): @@ -73,16 +74,31 @@ def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()): _log.info("[%s] sparsifying and normalizing matrix", timer) r_mat = data.interactions().matrix().scipy(layout="csr", legacy=True) - _log.info("[%s] training NMF", timer) - W, H, n_iter = non_negative_factorization( - r_mat, - beta_loss=self.config.beta_loss, - max_iter=self.config.max_iter, - n_components=self.config.n_components, - alpha_W=self.config.alpha_W, - alpha_H=self.config.alpha_H, - l1_ratio=self.config.l1_ratio, - ) + _log.info("[%s] training NMF (%s)", timer, self.config.method) + + if self.config.method == "full": + W, H, n_iter = non_negative_factorization( + r_mat, + beta_loss=self.config.beta_loss, + max_iter=self.config.max_iter, + n_components=self.config.n_components, + alpha_W=self.config.alpha_W, + alpha_H=self.config.alpha_H, + l1_ratio=self.config.l1_ratio, + ) + else: # minibatch + model = MiniBatchNMF( + beta_loss=self.config.beta_loss, + max_iter=self.config.max_iter, + n_components=self.config.n_components, + alpha_W=self.config.alpha_W, + alpha_H=self.config.alpha_H, + l1_ratio=self.config.l1_ratio, + ) + W = model.fit_transform(r_mat) + H = model.components_ + n_iter = model.n_iter_ + _log.info("[%s] Trained NMF in %d iterations", timer, n_iter) self.user_components = np.require(W, dtype=np.float32) diff --git a/tests/sklearn/test_nmf.py b/tests/sklearn/test_nmf.py index 8f9e98092..893830566 100644 --- a/tests/sklearn/test_nmf.py +++ b/tests/sklearn/test_nmf.py @@ -31,48 +31,40 @@ class TestNMF(BasicComponentTests, ScorerTests): expected_ndcg = 0.22 -def test_nmf_basic_build(): - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_basic_build(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) assert algo.user_components.shape == (3, 2) -def test_nmf_predict_basic(): - _log.info("NMF input data:\n%s", simple_df) - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_predict_basic(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) - _log.info("user matrix:\n%s", str(algo.user_components)) - _log.info("item matrix:\n%s", str(algo.item_components)) preds = algo(10, ItemList([3])) - assert len(preds) == 1 preds = preds.scores("pandas", index="ids") - assert preds is not None assert preds.index[0] == 3 - assert preds.loc[3] >= 0 - assert preds.loc[3] <= 5 + assert 0 <= preds.loc[3] <= 5 -def test_nmf_predict_bad_item(): - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_predict_bad_item(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) preds = algo(10, ItemList([4])) - assert len(preds) == 1 preds = preds.scores("pandas", index="ids") - assert preds is not None - assert preds.index[0] == 4 assert np.isnan(preds.loc[4]) -def test_nmf_predict_bad_user(): - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_predict_bad_user(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) preds = algo(50, ItemList([3])) - assert len(preds) == 1 preds = preds.scores("pandas", index="ids") - assert preds is not None - assert preds.index[0] == 3 assert np.isnan(preds.loc[3]) diff --git a/uv.lock b/uv.lock index d57a4f3c0..e2a89b5c5 100644 --- a/uv.lock +++ b/uv.lock @@ -2326,7 +2326,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -2339,7 +2339,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -2371,9 +2371,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'darwin'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'darwin'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -2386,7 +2386,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -4268,7 +4268,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "sys_platform != 'darwin'" }, + { name = "setuptools", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" },