Conversation

@nil0x9 (Contributor) commented Dec 15, 2025

Currently, when one passes a loss_cfg.loss_reduction other than "token" on an Ascend NPU device, a RuntimeError (device mismatch) is raised at this line:

loss = (loss * loss_weight).sum()

The root cause is that on Ascend NPU devices the cu_seq_lens tensors are required to be on CPU. In build_batches_loss_kwargs, the device of loss_weight is inherited from num_grad_tokens -> boundaries -> cu_seq_lens, hence the mismatch.
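
A minimal sketch of the failing pattern and one possible fix. The names build_batches_loss_kwargs, loss_weight, and cu_seq_lens come from the description above; everything else (the helper function and its signature) is assumed for illustration and is not the actual patch.

```python
import torch

def reduce_weighted_loss(loss: torch.Tensor, loss_weight: torch.Tensor) -> torch.Tensor:
    # On Ascend NPU, cu_seq_lens (and tensors whose device is derived from it,
    # such as loss_weight) may live on CPU while `loss` lives on the NPU, so the
    # element-wise product would raise a device-mismatch RuntimeError.
    # Moving loss_weight onto loss.device before the multiply avoids that.
    loss_weight = loss_weight.to(loss.device)
    return (loss * loss_weight).sum()
```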

@nil0x9 nil0x9 marked this pull request as ready for review December 15, 2025 17:38
@nil0x9 nil0x9 force-pushed the linty/fix-npu-loss-weight-device-mismatch branch from 8613e2d to f1294d3 on December 16, 2025 15:40
@nil0x9 nil0x9 force-pushed the linty/fix-npu-loss-weight-device-mismatch branch from f1294d3 to dacb7ef on December 24, 2025 08:03