From 4cbeded56f857469ba6399e56c943ab9cc7b8d43 Mon Sep 17 00:00:00 2001 From: cyw37 Date: Tue, 2 Sep 2025 20:21:47 -0400 Subject: [PATCH 1/3] Added minibatchNMF and added scikit-learn to packages --- pyproject.toml | 1 + src/lenskit/sklearn/nmf.py | 38 +++++++++++++++++++++++++++----------- tests/sklearn/test_nmf.py | 36 ++++++++++++++---------------------- uv.lock | 4 +++- 4 files changed, 45 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45b684ce1..51e0cf230 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "prettytable ~=3.14", "xopen ~=2.0", "lazy-loader >=0.4", + "scikit-learn>=1.7.1", ] [dependency-groups] diff --git a/src/lenskit/sklearn/nmf.py b/src/lenskit/sklearn/nmf.py index 7733cef8b..24456ae05 100644 --- a/src/lenskit/sklearn/nmf.py +++ b/src/lenskit/sklearn/nmf.py @@ -15,7 +15,7 @@ import numpy as np from pydantic import AliasChoices, BaseModel, Field -from sklearn.decomposition import non_negative_factorization +from sklearn.decomposition import non_negative_factorization, MiniBatchNMF from typing_extensions import Literal, override from lenskit.data import Dataset, ItemList, QueryInput, RecQuery @@ -44,6 +44,7 @@ class NMFConfig(BaseModel, extra="forbid"): alpha_W: float = 0.0 alpha_H: float | Literal["same"] = "same" l1_ratio: float = 0.0 + method: Literal["full", "minibatch"] = "full" class NMFScorer(Component[ItemList], Trainable): @@ -73,16 +74,31 @@ def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()): _log.info("[%s] sparsifying and normalizing matrix", timer) r_mat = data.interactions().matrix().scipy(layout="csr", legacy=True) - _log.info("[%s] training NMF", timer) - W, H, n_iter = non_negative_factorization( - r_mat, - beta_loss=self.config.beta_loss, - max_iter=self.config.max_iter, - n_components=self.config.n_components, - alpha_W=self.config.alpha_W, - alpha_H=self.config.alpha_H, - l1_ratio=self.config.l1_ratio, - ) + _log.info("[%s] training NMF (%s)", timer, self.config.method) + + if self.config.method == "full": + W, H, n_iter = non_negative_factorization( + r_mat, + beta_loss=self.config.beta_loss, + max_iter=self.config.max_iter, + n_components=self.config.n_components, + alpha_W=self.config.alpha_W, + alpha_H=self.config.alpha_H, + l1_ratio=self.config.l1_ratio, + ) + else: # minibatch + model = MiniBatchNMF( + beta_loss=self.config.beta_loss, + max_iter=self.config.max_iter, + n_components=self.config.n_components, + alpha_W=self.config.alpha_W, + alpha_H=self.config.alpha_H, + l1_ratio=self.config.l1_ratio, + ) + W = model.fit_transform(r_mat) + H = model.components_ + n_iter = model.n_iter_ + _log.info("[%s] Trained NMF in %d iterations", timer, n_iter) self.user_components = np.require(W, dtype=np.float32) diff --git a/tests/sklearn/test_nmf.py b/tests/sklearn/test_nmf.py index 8f9e98092..12d658441 100644 --- a/tests/sklearn/test_nmf.py +++ b/tests/sklearn/test_nmf.py @@ -31,48 +31,40 @@ class TestNMF(BasicComponentTests, ScorerTests): expected_ndcg = 0.22 -def test_nmf_basic_build(): - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_basic_build(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) assert algo.user_components.shape == (3, 2) -def test_nmf_predict_basic(): - _log.info("NMF input data:\n%s", simple_df) - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_predict_basic(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) - _log.info("user matrix:\n%s", str(algo.user_components)) - _log.info("item matrix:\n%s", str(algo.item_components)) preds = algo(10, ItemList([3])) - assert len(preds) == 1 preds = preds.scores("pandas", index="ids") - assert preds is not None assert preds.index[0] == 3 - assert preds.loc[3] >= 0 - assert preds.loc[3] <= 5 + assert 0 <= preds.loc[3] <= 5 -def test_nmf_predict_bad_item(): - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_predict_bad_item(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) preds = algo(10, ItemList([4])) - assert len(preds) == 1 preds = preds.scores("pandas", index="ids") - assert preds is not None - assert preds.index[0] == 4 assert np.isnan(preds.loc[4]) -def test_nmf_predict_bad_user(): - algo = nmf.NMFScorer(n_components=2) +@mark.parametrize("method", ["full", "minibatch"]) +def test_nmf_predict_bad_user(method): + algo = nmf.NMFScorer(n_components=2, method=method) algo.train(simple_ds) preds = algo(50, ItemList([3])) - assert len(preds) == 1 preds = preds.scores("pandas", index="ids") - assert preds is not None - assert preds.index[0] == 3 - assert np.isnan(preds.loc[3]) + assert np.isnan(preds.loc[3]) \ No newline at end of file diff --git a/uv.lock b/uv.lock index dfaff33ec..03181fc34 100644 --- a/uv.lock +++ b/uv.lock @@ -1324,6 +1324,7 @@ dependencies = [ { name = "pyzmq" }, { name = "requests" }, { name = "rich" }, + { name = "scikit-learn" }, { name = "scipy" }, { name = "structlog" }, { name = "threadpoolctl" }, @@ -1472,6 +1473,7 @@ requires-dist = [ { name = "ray", marker = "extra == 'ray'", specifier = "~=2.42" }, { name = "requests", specifier = "~=2.28" }, { name = "rich", specifier = ">=13.5" }, + { name = "scikit-learn", specifier = ">=1.7.1" }, { name = "scikit-learn", marker = "extra == 'sklearn'", specifier = "~=1.2" }, { name = "scipy", specifier = ">=1.11" }, { name = "structlog", specifier = ">=23.2" }, @@ -3561,7 +3563,7 @@ resolution-markers = [ "python_full_version < '3.12' and sys_platform == 'darwin'", ] dependencies = [ - { name = "mpmath" }, + { name = "mpmath", marker = "extra == 'group-7-lenskit-cpu' or extra != 'group-7-lenskit-gpu'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } wheels = [ From 48cf85a3a23a126857d4fc3d96310e541563d5b5 Mon Sep 17 00:00:00 2001 From: cyw37 Date: Wed, 3 Sep 2025 14:03:17 -0400 Subject: [PATCH 2/3] Removed scikit-learn from dependencies --- pyproject.toml | 1 - uv.lock | 18 ++++++++---------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6d2e03a1..8f279d0db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ dependencies = [ "prettytable ~=3.14", "xopen ~=2.0", "lazy-loader >=0.4", - "scikit-learn>=1.7.1", ] [dependency-groups] diff --git a/uv.lock b/uv.lock index 1ffaea9ba..04fd49f2a 100644 --- a/uv.lock +++ b/uv.lock @@ -1463,7 +1463,6 @@ dependencies = [ { name = "pyzmq" }, { name = "requests" }, { name = "rich" }, - { name = "scikit-learn" }, { name = "scipy" }, { name = "structlog" }, { name = "threadpoolctl" }, @@ -1609,7 +1608,6 @@ requires-dist = [ { name = "ray", marker = "extra == 'ray'", specifier = "~=2.42" }, { name = "requests", specifier = "~=2.28" }, { name = "rich", specifier = ">=13.5" }, - { name = "scikit-learn", specifier = ">=1.7.1" }, { name = "scikit-learn", marker = "extra == 'sklearn'", specifier = "~=1.2" }, { name = "scipy", specifier = ">=1.11" }, { name = "structlog", specifier = ">=23.2" }, @@ -2286,7 +2284,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, @@ -2299,7 +2297,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -2331,9 +2329,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform != 'darwin'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform != 'darwin'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-cublas-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, + { name = "nvidia-cusparse-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -2346,7 +2344,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform != 'darwin'" }, + { name = "nvidia-nvjitlink-cu12", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -3918,7 +3916,7 @@ name = "sympy" version = "1.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "mpmath", marker = "extra == 'group-7-lenskit-cpu' or extra != 'group-7-lenskit-gpu'" }, + { name = "mpmath" }, ] sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } wheels = [ @@ -4228,7 +4226,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "sys_platform != 'darwin'" }, + { name = "setuptools", marker = "(sys_platform != 'darwin' and extra != 'group-7-lenskit-cpu') or (extra == 'group-7-lenskit-cpu' and extra == 'group-7-lenskit-gpu')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" }, From bdb66275d2f9b1f1858b8508d4f8f52e04a0357b Mon Sep 17 00:00:00 2001 From: cyw37 Date: Mon, 8 Sep 2025 11:05:42 -0400 Subject: [PATCH 3/3] sorted import and ran pre-commit --- src/lenskit/sklearn/nmf.py | 2 +- tests/sklearn/test_nmf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lenskit/sklearn/nmf.py b/src/lenskit/sklearn/nmf.py index 24456ae05..e6ff897d6 100644 --- a/src/lenskit/sklearn/nmf.py +++ b/src/lenskit/sklearn/nmf.py @@ -15,7 +15,7 @@ import numpy as np from pydantic import AliasChoices, BaseModel, Field -from sklearn.decomposition import non_negative_factorization, MiniBatchNMF +from sklearn.decomposition import MiniBatchNMF, non_negative_factorization from typing_extensions import Literal, override from lenskit.data import Dataset, ItemList, QueryInput, RecQuery diff --git a/tests/sklearn/test_nmf.py b/tests/sklearn/test_nmf.py index 12d658441..893830566 100644 --- a/tests/sklearn/test_nmf.py +++ b/tests/sklearn/test_nmf.py @@ -67,4 +67,4 @@ def test_nmf_predict_bad_user(method): preds = algo(50, ItemList([3])) preds = preds.scores("pandas", index="ids") - assert np.isnan(preds.loc[3]) \ No newline at end of file + assert np.isnan(preds.loc[3])