From a2d3084c52e7ae88c6a20e4aaa82267d1eb93e62 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 21 Jan 2026 09:55:08 -0500 Subject: [PATCH] batch inference calls in LightGCN --- src/lenskit/graphs/lightgcn.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/lenskit/graphs/lightgcn.py b/src/lenskit/graphs/lightgcn.py index 1a814381b..849d950e6 100644 --- a/src/lenskit/graphs/lightgcn.py +++ b/src/lenskit/graphs/lightgcn.py @@ -31,6 +31,8 @@ _log = logging.get_logger(__name__) +INF_BATCH_SIZE = 8192 + class LightGCNConfig(EmbeddingSizeMixin, BaseModel): """ @@ -146,18 +148,25 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList: # look up the item columns in the embedding matrix i_cols = items.numbers(vocabulary=self.items, missing="negative", format="torch") - i_cols = i_cols.to(self._edges.device, dtype=self._edges.dtype) + i_cols = i_cols.to(self._edges.device, dtype=self._edges.dtype, non_blocking=True) # unknown items will have column -1 - limit to the # ones we know, and remember which item IDs those are scorable_mask = i_cols.ge(0) i_cols = i_cols.masked_select(scorable_mask) + n = len(i_cols) # set up the edge tensor - u_tensor = torch.from_numpy(np.repeat(np.array([u_row + self._user_base]), len(i_cols))) - u_tensor = u_tensor.to(self._edges.device, dtype=self._edges.dtype) + u_tensor = torch.from_numpy(np.repeat(np.array([u_row + self._user_base]), n)) + u_tensor = u_tensor.to(self._edges.device, dtype=self._edges.dtype, non_blocking=True) edges = torch.stack([u_tensor, i_cols]) - scores = self.model(self._edges, edges) + assert edges.shape == (2, n) + + scores = torch.zeros(n, device=edges.device) + # we work in batches to reduce inference memory usage + for bs in range(0, len(i_cols), INF_BATCH_SIZE): + be = min(bs + INF_BATCH_SIZE, n) + scores[bs:be] = self.model(self._edges, edges[:, bs:be]) # initialize output score array, fill with missing full_scores = torch.full((len(items),), np.nan, dtype=torch.float32, device=scores.device) @@ -177,6 +186,10 @@ def create_trainer(self, data, options): class LightGCNTrainer(ModelTrainer): + """ + Model trainer for :class:`LightGCNScorer`. + """ + scorer: LightGCNScorer data: Dataset options: TrainingOptions