diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py
index dd6c6a3b0..c8164908e 100644
--- a/tests/pytorch/test_parallel_cross_entropy.py
+++ b/tests/pytorch/test_parallel_cross_entropy.py
@@ -3,10 +3,11 @@
 # See LICENSE for license information.

 import random
-import pytest
 import torch

 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
+from utils import dtype_tols
+


 class TestParallelCrossEntropy:
@@ -19,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float):
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )

-    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
-
+    def generate_input(
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        ignore_idx: bool,
+        device: torch.device = "cuda",
+    ):
         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
         ignore = random.sample(range(0, SQ - 1), 5)

+        # Generate random data
         if swap_dim:
-            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
+            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
         else:
-            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
+            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)

         if ignore_idx:
             for i in ignore:
@@ -41,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
                 else:
                     self.tar_test[0][i] = -100

+        # Make copy of data for reference implementation
         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

+        # Enable autograd
+        self.input_test.requires_grad_()
+        self.input_ref.requires_grad_()
+
     def one_iteration_test(
         self,
         dtype: torch.dtype,
@@ -53,18 +65,20 @@ def one_iteration_test(
         ignore_idx: bool = False,
     ):

+        # Random data
         self.generate_input(dtype, swap_dim, ignore_idx)

-        self.input_test.requires_grad_(True)
-        self.input_ref.requires_grad_(True)
-
+        # Forward pass
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )
-
         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)

-        # Handle backward pass based on the test scenario
+        # Compute square to avoid trivial backward pass
+        test_loss = torch.square(test_loss)
+        ref_loss = torch.square(ref_loss)
+
+        # Backward pass
         if reduce_loss:
             test_loss.backward()
             ref_loss.backward()
@@ -72,16 +86,18 @@ def one_iteration_test(
             test_loss.sum().backward()
             ref_loss.sum().backward()

-        test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
-
-        if ignore_idx:
-            print(test_loss, ref_loss)
-
-        # Compare gradients when backward pass was called
-        torch.testing.assert_close(
-            torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-        )
-
+        # Check that loss and grad input match
+        tols = dtype_tols(dtype)
+        test_loss = test_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.reshape(test_loss.size())
+        test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
+        torch.testing.assert_close(test_loss, ref_loss, **tols)
+        torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)
+
+        # Reset data
         self.input_test = None
         self.input_ref = None
         self.tar_test = None
@@ -137,4 +153,4 @@ def test_ignore_idx(self):
             label_smoothing=0,
             reduce_loss=False,
             ignore_idx=True,
-        )
+        )
\ No newline at end of file
diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py
index 75b5de37b..0d05babb6 100644
--- a/transformer_engine/pytorch/cross_entropy.py
+++ b/transformer_engine/pytorch/cross_entropy.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -76,4 +77,4 @@ def backward(ctx, grad_output):
         )


-parallel_cross_entropy = CrossEntropyFunction.apply
+parallel_cross_entropy = CrossEntropyFunction.apply
\ No newline at end of file
diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py
index 20e0b737d..a0431fe18 100644
--- a/transformer_engine/pytorch/triton/cross_entropy.py
+++ b/transformer_engine/pytorch/triton/cross_entropy.py
@@ -17,7 +17,6 @@
 import triton.language as tl
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

-
 @triton.jit
 def online_softmax_kernel(
     X_ptr,
@@ -118,7 +117,7 @@ def cross_entropy_kernel(
         m_d_X_y_stride: The stride of m/d/X_y tensor.
         rank (int): The rank of this device in the TP group.
         world_size (int): The size of world involved in this distributed loss calculation.
-        ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
+        ignore_idx (int): Tokens to be ignored for loss and gradient calculation. (default -100)
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (int): The number of non-ignored elements in the batch.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -231,11 +230,13 @@ def cross_entropy_kernel(
 else:
     NUM_WARPS = 32

+
 @triton.jit
 def element_mul_kernel(
     X_ptr,
     X_stride,
     grad_output_ptr,
+    grad_output_stride,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -258,6 +259,7 @@ def element_mul_kernel(
     X_ptr += program_id * X_stride

     # Load the gradient output value
+    grad_output_ptr += program_id * grad_output_stride
     grad_output = tl.load(grad_output_ptr)

     # Perform the element-wise multiplication
@@ -279,6 +281,8 @@ def cross_entropy_forward(
     B, SQ, V = _input.shape
     n_rows = B * SQ

+    valid_token_count = int((target != ignore_idx).sum().item())
+    denom = max(1, valid_token_count)

     assert reduce(mul, list(target.size())) == (B * SQ), "Each token needs a target token ID."

@@ -334,25 +338,29 @@ def cross_entropy_forward(
         world_size=world_size,
         ignore_idx=ignore_idx,
         n_cols=V,
-        n_non_ignore=n_rows,
+        n_non_ignore=denom,
         reduce_loss=reduce_loss,
         label_smoothing=label_smoothing,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=NUM_WARPS,
     )

-    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
+    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / denom)

     return loss, _input


-def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
+def cross_entropy_backward(
+    _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
+):
     """Backward implementation of cross entropy loss kernel"""

     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+    # Only check torch.equal when not in CUDA graph capturable mode
+    if not is_cg_capturable and torch.equal(
+        grad_output, torch.tensor(1.0, device=grad_output.device)
+    ):
         pass
-
     else:
         B, SQ, V = _input.shape
         n_rows = B * SQ
@@ -362,9 +370,10 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
             _input,
             _input.stride(-2),
             grad_output,
+            1 if grad_output.numel() > 1 else 0,
             V,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=NUM_WARPS,
         )

-    return _input
+    return _input
\ No newline at end of file
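
For context, a minimal standalone sketch (plain PyTorch only, hypothetical shapes chosen for illustration, not Transformer Engine code) of the reduction behaviour that the new valid_token_count / denom denominator mirrors: torch.nn.functional.cross_entropy with reduction="mean" divides by the number of non-ignored tokens rather than by batch * seq.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes (hypothetical, for this sketch only)
    batch, seq, vocab = 2, 8, 32
    logits = torch.randn(batch * seq, vocab)
    target = torch.randint(0, vocab, (batch * seq,))
    target[::4] = -100  # default ignore_index marks these tokens as ignored

    # "mean" divides by the count of non-ignored tokens, not by batch * seq
    loss_mean = F.cross_entropy(logits, target, reduction="mean")
    loss_none = F.cross_entropy(logits, target, reduction="none")
    valid_token_count = int((target != -100).sum())
    assert torch.allclose(loss_mean, loss_none.sum() / valid_token_count)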