From a71bbf00499160e226121aa63ada54405b65dbd3 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 22 Dec 2025 20:54:01 -0800 Subject: [PATCH] fix DTensor slice crash after pytorch 2.9 bump Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/loss_functions.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 5e9afe15d1..459181c899 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -922,11 +922,12 @@ def __call__( if context_parallel_group is None else torch.distributed.get_world_size(context_parallel_group) ) - logit_slice_idxs = slice( - seq_start // cp_size, - (seq_start + padded_seq_lengths[seq_idx]) // cp_size, + logit_start = seq_start // cp_size + logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size + logit_length = logit_end - logit_start + next_token_logits_slice = next_token_logits.narrow( + 1, logit_start, logit_length ) - next_token_logits_slice = next_token_logits[:, logit_slice_idxs, :] loss, metrics = self.loss_fn( next_token_logits_slice,