diff --git a/Cargo.toml b/Cargo.toml index 3c71b4bf4..591f1bfe1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,3 +42,6 @@ opt-level = 2 [profile.release] lto = true debug = "line-tables-only" + +[features] +fuse-parallel = [] diff --git a/pipelines/slim.toml b/pipelines/slim.toml new file mode 100644 index 000000000..6052696dd --- /dev/null +++ b/pipelines/slim.toml @@ -0,0 +1,8 @@ +[meta] +name = "slim" + +[options] +base = "std:topn" + +[components.scorer] +class = "lenskit.knn.SLIMScorer" diff --git a/src/accel/als/explicit.rs b/src/accel/als/explicit.rs index 9f7cf494f..5c3731ef1 100644 --- a/src/accel/als/explicit.rs +++ b/src/accel/als/explicit.rs @@ -17,6 +17,7 @@ use rayon::prelude::*; use log::*; use crate::{ + parallel::maybe_fuse, progress::ProgressHandle, sparse::{CSRMatrix, CSR}, }; @@ -46,9 +47,7 @@ pub(super) fn train_explicit_matrix<'py>( ); let frob: f32 = py.allow_threads(|| { - this.outer_iter_mut() - .into_par_iter() - .enumerate() + maybe_fuse(this.outer_iter_mut().into_par_iter().enumerate()) .map(|(i, row)| { let f = train_row_solve(&matrix, i, row, &other, reg); progress.tick(); diff --git a/src/accel/als/implicit.rs b/src/accel/als/implicit.rs index ecd0c8443..31cf2e5e8 100644 --- a/src/accel/als/implicit.rs +++ b/src/accel/als/implicit.rs @@ -17,6 +17,7 @@ use rayon::prelude::*; use log::*; use crate::{ + parallel::maybe_fuse, progress::ProgressHandle, sparse::{CSRMatrix, CSR}, }; @@ -48,9 +49,7 @@ pub(super) fn train_implicit_matrix<'py>( other.nrows() ); let frob: f32 = py.allow_threads(|| { - this.outer_iter_mut() - .into_par_iter() - .enumerate() + maybe_fuse(this.outer_iter_mut().into_par_iter().enumerate()) .map(|(i, row)| { let f = train_row_solve(&matrix, i, row, &other, &otor); progress.tick(); diff --git a/src/accel/lib.rs b/src/accel/lib.rs index 649ad63b7..acb82fb39 100644 --- a/src/accel/lib.rs +++ b/src/accel/lib.rs @@ -4,10 +4,7 @@ // Licensed under the MIT license, see LICENSE.md for details. // SPDX-License-Identifier: MIT -use log::*; -use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -use rayon::{current_num_threads, ThreadPoolBuilder}; mod als; mod arrow; @@ -16,7 +13,9 @@ mod data; mod funksvd; mod indirect_hashing; mod knn; +mod parallel; mod progress; +mod slim; mod sparse; /// Entry point for LensKit accelerator module. @@ -26,29 +25,13 @@ fn _accel(m: &Bound<'_, PyModule>) -> PyResult<()> { knn::register_knn(m)?; als::register_als(m)?; data::register_data(m)?; + slim::register_slim(m)?; m.add_class::()?; - m.add_function(wrap_pyfunction!(init_accel_pool, m)?)?; - m.add_function(wrap_pyfunction!(thread_count, m)?)?; + m.add_function(wrap_pyfunction!(parallel::init_accel_pool, m)?)?; + m.add_function(wrap_pyfunction!(parallel::thread_count, m)?)?; m.add_function(wrap_pyfunction!(sparse::sparse_row_debug_type, m)?)?; m.add_function(wrap_pyfunction!(sparse::sparse_structure_debug_large, m)?)?; Ok(()) } - -#[pyfunction] -fn init_accel_pool(n_threads: usize) -> PyResult<()> { - debug!( - "initializing accelerator thread pool with {} threads", - n_threads - ); - ThreadPoolBuilder::new() - .num_threads(n_threads) - .build_global() - .map_err(|_| PyErr::new::("Rayon initialization error")) -} - -#[pyfunction] -fn thread_count() -> PyResult { - Ok(current_num_threads()) -} diff --git a/src/accel/parallel.rs b/src/accel/parallel.rs new file mode 100644 index 000000000..97649acac --- /dev/null +++ b/src/accel/parallel.rs @@ -0,0 +1,39 @@ +// This file is part of LensKit. +// Copyright (C) 2018-2023 Boise State University. +// Copyright (C) 2023-2025 Drexel University. +// Licensed under the MIT license, see LICENSE.md for details. +// SPDX-License-Identifier: MIT + +use log::*; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; + +#[cfg(feature = "fuse-parallel")] +use rayon::iter::PanicFuse; +use rayon::{current_num_threads, iter::ParallelIterator, ThreadPoolBuilder}; + +#[pyfunction] +pub fn init_accel_pool(n_threads: usize) -> PyResult<()> { + debug!( + "initializing accelerator thread pool with {} threads", + n_threads + ); + ThreadPoolBuilder::new() + .num_threads(n_threads) + .build_global() + .map_err(|_| PyErr::new::("Rayon initialization error")) +} + +#[pyfunction] +pub fn thread_count() -> PyResult { + Ok(current_num_threads()) +} + +#[cfg(not(feature = "fuse-parallel"))] +pub fn maybe_fuse(iter: I) -> I { + iter +} + +#[cfg(feature = "fuse-parallel")] +pub fn maybe_fuse(iter: I) -> PanicFuse { + iter.panic_fuse() +} diff --git a/src/accel/slim/mod.rs b/src/accel/slim/mod.rs new file mode 100644 index 000000000..d5f8ac729 --- /dev/null +++ b/src/accel/slim/mod.rs @@ -0,0 +1,204 @@ +// This file is part of LensKit. +// Copyright (C) 2018-2023 Boise State University. +// Copyright (C) 2023-2025 Drexel University. +// Licensed under the MIT license, see LICENSE.md for details. +// SPDX-License-Identifier: MIT + +//! Sparse Linear Methods for recommendation. + +use log::*; +use pyo3::{exceptions::PyValueError, prelude::*}; +use rayon::prelude::*; + +use arrow::{ + array::{make_array, Array, ArrayData}, + pyarrow::PyArrowType, +}; + +use crate::{ + parallel::maybe_fuse, + sparse::{ArrowCSRConsumer, CSRStructure, CSR}, +}; + +const EPSILON: f64 = 1.0e-12; +// default value from Karypis code +const OPT_TOLERANCE: f64 = 1e-7; + +#[derive(Debug, Clone, Copy)] +struct SLIMOptions { + l1_reg: f64, + l2_reg: f64, + max_iters: u32, +} + +struct SLIMWorkspace<'a> { + options: SLIMOptions, + ui_matrix: &'a CSRStructure, + iu_matrix: &'a CSRStructure, + n_users: usize, + n_items: usize, +} + +/// Register the lenskit._accel.slim module +pub fn register_slim(parent: &Bound<'_, PyModule>) -> PyResult<()> { + let slim = PyModule::new(parent.py(), "slim")?; + parent.add_submodule(&slim)?; + slim.add_function(wrap_pyfunction!(train_slim, &slim)?)?; + + Ok(()) +} + +/// Learn SLIM regression weights. +/// +/// This returns the **transpose** of the weight matrix, for convenient +/// implementation. +#[pyfunction] +fn train_slim<'py>( + py: Python<'py>, + ui_matrix: PyArrowType, + iu_matrix: PyArrowType, + l1_reg: f64, + l2_reg: f64, + max_iters: u32, + progress: Bound<'py, PyAny>, +) -> PyResult>> { + let ui_matrix = make_array(ui_matrix.0); + let ui_matrix = CSRStructure::::from_arrow(ui_matrix)?; + let iu_matrix = make_array(iu_matrix.0); + let iu_matrix = CSRStructure::::from_arrow(iu_matrix)?; + + if ui_matrix.n_rows != iu_matrix.n_cols { + return Err(PyValueError::new_err("user count mismatch")); + } + if ui_matrix.n_cols != iu_matrix.n_rows { + return Err(PyValueError::new_err("item count mismatch")); + } + if ui_matrix.nnz() != iu_matrix.nnz() { + return Err(PyValueError::new_err("rating count mismatch")); + } + + let progress = if progress.is_none() { + None + } else { + Some(progress.unbind()) + }; + let options = SLIMOptions { + l1_reg, + l2_reg, + max_iters, + }; + + debug!("computing similarity rows"); + let collector = if let Some(pb) = progress { + ArrowCSRConsumer::with_progress(ui_matrix.n_cols, pb) + } else { + ArrowCSRConsumer::new(ui_matrix.n_cols) + }; + + let result = py.allow_threads(move || { + let range = 0..ui_matrix.n_cols; + let chunks = maybe_fuse(range.into_par_iter()) + .map_init( + || SLIMWorkspace::create(&ui_matrix, &iu_matrix, &options), + SLIMWorkspace::compute_column, + ) + .drive(collector); + chunks.into_iter().map(|a| a.into_data().into()).collect() + }); + + Ok(result) +} + +impl<'a> SLIMWorkspace<'a> { + fn create( + ui_matrix: &'a CSRStructure, + iu_matrix: &'a CSRStructure, + options: &SLIMOptions, + ) -> Self { + let n_items = ui_matrix.n_cols; + let n_users = ui_matrix.n_rows; + SLIMWorkspace { + options: *options, + ui_matrix, + iu_matrix, + n_users, + n_items, + } + } + + /// Train a single column of the SLIM weight matrix. + /// + /// This code was written from the papers, referencing Karypis's LIBSLIM for + /// ideas on implementation details. The relevant LIBSLIM source code + /// is at https://github.com/KarypisLab/SLIM/tree/master/src/libslim. + fn compute_column(&mut self, item: usize) -> Vec<(i32, f32)> { + // get the active users for this item + let i_users = self.iu_matrix.row_cols(item); + // since it's all 1s, the length of active entries is the squared norm + let sq_cnorm = i_users.len() as f64; + + let mut weights = vec![0.0; self.n_items]; + let mut estimates = vec![0.0; self.n_users]; + + for iter in 0..self.options.max_iters { + let mut sqdelta = 0.0; + // coordinate descent - loop over items, learn that row in the weight vector + for i in 0..self.n_items { + let old_w = weights[i]; + // subtract this item's contribution to the estimate + if old_w > 0.0 { + for c in i_users { + estimates[*c as usize] -= old_w + } + } + + // compute the update value - sum errors where user is active (so rating is 1) + let mut update = 0.0; + for u in i_users { + let u = *u as usize; + update += 1.0 - estimates[u]; + } + // convert to mean + update /= self.n_users as f64; + + // soft-threshold and adjust + let new = if update >= self.options.l1_reg { + let num = update - self.options.l1_reg; + num / (sq_cnorm - self.options.l2_reg) + } else { + 0.0 + }; + let delta = new - old_w; + sqdelta += delta * delta; + weights[i] = new; + + // update estimates + if new > 0.0 { + for c in i_users { + estimates[*c as usize] += new + } + } + } + if sqdelta <= OPT_TOLERANCE { + debug!("finished column {} after {} iters", item, iter + 1); + break; + } + } + + // sparsify weights for final result + let res: Vec<_> = weights + .into_iter() + .enumerate() + .filter_map(|(i, v)| { + if v >= EPSILON { + Some((i as i32, v as f32)) + } else { + None + } + }) + .collect(); + + // and we're done! + res + } +} diff --git a/src/accel/sparse/consumer.rs b/src/accel/sparse/consumer.rs index 78c1adeb5..19a24f7c5 100644 --- a/src/accel/sparse/consumer.rs +++ b/src/accel/sparse/consumer.rs @@ -58,6 +58,7 @@ impl ArrowCSRConsumer { val_bld: Float32Builder::new(), } } + pub(crate) fn new(dim: usize) -> Self { Self::from_state(CSRState::new(dim, None)) } diff --git a/src/lenskit/_accel/__init__.pyi b/src/lenskit/_accel/__init__.pyi index 7ef6091e7..8075e40b2 100644 --- a/src/lenskit/_accel/__init__.pyi +++ b/src/lenskit/_accel/__init__.pyi @@ -9,11 +9,12 @@ import pyarrow as pa from lenskit.data.types import NPMatrix, NPVector from lenskit.funksvd import FunkSVDTrainingData, FunkSVDTrainingParams -from . import als, data +from . import als, data, slim __all__ = [ "als", "data", + "slim", "init_accel_pool", "thread_count", "FunkSVDTrainer", diff --git a/src/lenskit/_accel/slim.pyi b/src/lenskit/_accel/slim.pyi new file mode 100644 index 000000000..a3f61358d --- /dev/null +++ b/src/lenskit/_accel/slim.pyi @@ -0,0 +1,13 @@ +from collections.abc import Sequence + +from lenskit.data.matrix import SparseRowArray +from lenskit.logging import Progress + +def train_slim( + ui_matrix: SparseRowArray, + iu_matrix: SparseRowArray, + l1_reg: float, + l2_reg: float, + max_iters: int, + progress: Progress | None, +) -> Sequence[SparseRowArray]: ... diff --git a/src/lenskit/data/matrix.py b/src/lenskit/data/matrix.py index 49ff19fd7..afee6821e 100644 --- a/src/lenskit/data/matrix.py +++ b/src/lenskit/data/matrix.py @@ -483,6 +483,10 @@ def values(self) -> pa.Array | None: else: return None + @property + def nnz(self) -> int: + return self.offsets[len(self)].as_py() + def structure(self) -> SparseRowArray: """ Get the structure of this matrix (without values). diff --git a/src/lenskit/knn/__init__.py b/src/lenskit/knn/__init__.py index 58ea96cfe..579a115bf 100644 --- a/src/lenskit/knn/__init__.py +++ b/src/lenskit/knn/__init__.py @@ -9,6 +9,14 @@ """ from .item import ItemKNNConfig, ItemKNNScorer +from .slim import SLIMConfig, SLIMScorer from .user import UserKNNConfig, UserKNNScorer -__all__ = ["ItemKNNScorer", "ItemKNNConfig", "UserKNNScorer", "UserKNNConfig"] +__all__ = [ + "ItemKNNScorer", + "ItemKNNConfig", + "UserKNNScorer", + "UserKNNConfig", + "SLIMScorer", + "SLIMConfig", +] diff --git a/src/lenskit/knn/slim.py b/src/lenskit/knn/slim.py new file mode 100644 index 000000000..c4ab23666 --- /dev/null +++ b/src/lenskit/knn/slim.py @@ -0,0 +1,116 @@ +# This file is part of LensKit. +# Copyright (C) 2018-2023 Boise State University. +# Copyright (C) 2023-2025 Drexel University. +# Licensed under the MIT license, see LICENSE.md for details. +# SPDX-License-Identifier: MIT + +""" +Sparse LInear Methods for Recommendation :cite:p:`ningSLIMSparseLinear2011`. +""" + +import warnings + +import numpy as np +import pyarrow as pa +from pydantic import BaseModel, PositiveFloat, PositiveInt +from scipy.sparse import csr_array + +from lenskit._accel import slim as _slim_accel +from lenskit.data import Dataset, ItemList, RecQuery, Vocabulary +from lenskit.data.matrix import SparseRowArray +from lenskit.diagnostics import DataWarning +from lenskit.logging import get_logger, item_progress +from lenskit.parallel.config import ensure_parallel_init +from lenskit.pipeline.components import Component +from lenskit.training import Trainable, TrainingOptions + +_log = get_logger(__name__) + + +class SLIMConfig(BaseModel): + l1_reg: PositiveFloat = 0.005 + """ + L₁ regularization strength for SLIM. + """ + l2_reg: PositiveFloat = 0.01 + """ + L₂ regularization strength for SLIM. + """ + max_iters: PositiveInt = 50 + """ + Maximum iterations per column. + """ + + +class SLIMScorer(Component, Trainable): + """ + Item scorer using Sparse LInear Methods (SLIM). SLIM was described for + recommendation by :cite:t`ningSLIMSparseLinear2011`. This implementation + closely follows the paper, with some reference to `libslim`_ for + computational details. It uses coodrinate descent with soft thresholding + to estimate the SLIM weight matrix. + + .. _libslim: https://github.com/KarypisLab/SLIM/tree/master/src/libslim + """ + + config: SLIMConfig + + weights: csr_array + items: Vocabulary + + def train(self, data: Dataset, options: TrainingOptions): + if hasattr(self, "weights") and not options.retrain: + return + + ensure_parallel_init() + ui_matrix = data.interactions().matrix().csr_structure(format="arrow") + _log.info( + "training SLIM model with %d interactions for %d items", + ui_matrix.nnz, + ui_matrix.dimension, + ) + iu_matrix = ui_matrix.transpose() + + with item_progress("SLIM vectors", ui_matrix.dimension) as pb: + weights = _slim_accel.train_slim( + ui_matrix, + iu_matrix, + self.config.l1_reg, + self.config.l2_reg, + self.config.max_iters, + pb, + ) + weights = pa.chunked_array(weights).combine_chunks() + weights = SparseRowArray.from_array(weights) + _log.info("learned %d SLIM weights", weights.nnz) + self.weights = weights = weights.to_scipy().T.tocsr() + self.items = data.items + + def __call__(self, query: RecQuery, items: ItemList) -> ItemList: + u_items = query.user_items + if u_items is None: + warnings.warn("no user history available", DataWarning) + return ItemList(items, scores=np.nan) + + if len(u_items) == 0: + _log.debug("user %s has no history", query.user_id) + return ItemList(items, scores=np.nan) + + # get user item numbers + u_inos = u_items.numbers(vocabulary=self.items, missing="negative") + u_inos = u_inos[u_inos >= 0] + + # prepare our initial matrix + x = np.zeros(len(self.items)) + x[u_inos] = 1 + + # compute the scores + all_scores = x @ self.weights + + # finalize result + scores = np.full(len(items), np.nan, np.float32) + inos = items.numbers(vocabulary=self.items, missing="negative") + mask = inos >= 0 + scores[mask] = all_scores[inos[mask]] + + return ItemList(items, scores=scores) diff --git a/tests/data/test_arrow_sparse.py b/tests/data/test_arrow_sparse.py index 0253527a0..868739ef9 100644 --- a/tests/data/test_arrow_sparse.py +++ b/tests/data/test_arrow_sparse.py @@ -41,6 +41,7 @@ def test_sparse_from_csr(csr: csr_array[Any, tuple[int, int]]): assert len(arr.offsets) == csr.shape[0] + 1 assert len(arr.indices) == csr.nnz assert len(arr.values) == csr.nnz + assert arr.nnz == csr.nnz assert arr.offsets.to_numpy()[0] == 0 assert arr.offsets.to_numpy()[-1] == csr.nnz assert np.all(arr.offsets.to_numpy() == csr.indptr) diff --git a/tests/models/test_slim.py b/tests/models/test_slim.py new file mode 100644 index 000000000..18b73c9b8 --- /dev/null +++ b/tests/models/test_slim.py @@ -0,0 +1,37 @@ +# This file is part of LensKit. +# Copyright (C) 2018-2023 Boise State University. +# Copyright (C) 2023-2025 Drexel University. +# Licensed under the MIT license, see LICENSE.md for details. +# SPDX-License-Identifier: MIT + +import pyarrow as pa + +from lenskit._accel import slim as _slim_accel +from lenskit.data import Dataset +from lenskit.data.matrix import SparseRowArray +from lenskit.knn.slim import SLIMConfig, SLIMScorer +from lenskit.logging import get_logger +from lenskit.parallel.config import ensure_parallel_init +from lenskit.testing import ScorerTests + +_log = get_logger(__name__) + + +def test_slim_trainer(ml_ds: Dataset): + "Test internal SLIM training function." + ensure_parallel_init() + ui_matrix = ml_ds.interactions().matrix().csr_structure(format="arrow") + iu_matrix = ui_matrix.transpose() + + result = _slim_accel.train_slim(ui_matrix, iu_matrix, 0.005, 0.01, 10, None) + result = pa.chunked_array(result).combine_chunks() + result = SparseRowArray.from_array(result) + assert isinstance(result, SparseRowArray) + assert result.shape == (ml_ds.item_count, ml_ds.item_count) + _log.info("received result", nnz=result.nnz) + + +class TestSLIM(ScorerTests): + component = SLIMScorer + config = SLIMConfig(max_iters=10) + expected_ndcg = (0.01, 0.2)