From 2c78341b71c6b318f59406dee4ba8cc2897e770b Mon Sep 17 00:00:00 2001 From: miriambt Date: Tue, 15 Apr 2025 03:41:25 -0700 Subject: [PATCH 1/2] Fix dtype mismatch in TabrModel forward pass --- pytabkit/models/nn_models/tabr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytabkit/models/nn_models/tabr.py b/pytabkit/models/nn_models/tabr.py index 5fabd96..a9b023a 100644 --- a/pytabkit/models/nn_models/tabr.py +++ b/pytabkit/models/nn_models/tabr.py @@ -388,7 +388,7 @@ def forward( probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None]) + context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) values = context_y_emb + self.T(k[:, None] - context_k) context_x = (probs[:, None] @ values).squeeze(1) x = x + context_x From 4c313594d90babde49ea12a7b405e92ecdcb2843 Mon Sep 17 00:00:00 2001 From: miriambt Date: Wed, 16 Apr 2025 06:33:43 -0700 Subject: [PATCH 2/2] Conditionally cast context_y to .long() or .float() --- pytabkit/models/nn_models/tabr.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytabkit/models/nn_models/tabr.py b/pytabkit/models/nn_models/tabr.py index a9b023a..27f056a 100644 --- a/pytabkit/models/nn_models/tabr.py +++ b/pytabkit/models/nn_models/tabr.py @@ -388,7 +388,11 @@ def forward( probs = F.softmax(similarities, dim=-1) probs = self.dropout(probs) - context_y_emb = self.label_encoder(candidate_y[context_idx][..., None].float()) + context_y = candidate_y[context_idx][..., None] + if isinstance(self.label_encoder, nn.Sequential): + context_y_emb = self.label_encoder(context_y.long()) + else: + context_y_emb = self.label_encoder(context_y.float()) values = context_y_emb + self.T(k[:, None] - context_k) context_x = (probs[:, None] @ values).squeeze(1) x = x + context_x