diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 5e9afe15d1..459181c899 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -922,11 +922,12 @@ def __call__( if context_parallel_group is None else torch.distributed.get_world_size(context_parallel_group) ) - logit_slice_idxs = slice( - seq_start // cp_size, - (seq_start + padded_seq_lengths[seq_idx]) // cp_size, + logit_start = seq_start // cp_size + logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size + logit_length = logit_end - logit_start + next_token_logits_slice = next_token_logits.narrow( + 1, logit_start, logit_length ) - next_token_logits_slice = next_token_logits[:, logit_slice_idxs, :] loss, metrics = self.loss_fn( next_token_logits_slice,