Hotfix/fused ce triton #409
base: dev
Conversation
What is the purpose of this file change?
Modify the copyright date. The file is also maintained upstream, so avoid unnecessary reformatting.
```diff
  rank (int): The rank of this device in the TP group.
  world_size (int): The size of world involved in this distributed loss calculation.
- ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
+ ignore_idx (int): Tokens to be ignored for loss and gradient calculation. (default -100)
```
There is no default here.
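For the docstring's "(default -100)" to be accurate, the parameter needs an actual default in the signature. A minimal sketch of what that looks like (the wrapper name `cross_entropy_loss` and the use of `torch.nn.functional` here are illustrative, not the PR's actual code):

```python
import torch

def cross_entropy_loss(
    logits: torch.Tensor, target: torch.Tensor, ignore_idx: int = -100
) -> torch.Tensor:
    # ignore_idx has a real default of -100 in the signature,
    # so the docstring claim "(default -100)" is now true.
    return torch.nn.functional.cross_entropy(
        logits, target, ignore_index=ignore_idx
    )
```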
```diff
-def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
+def cross_entropy_backward(
+    _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
+):
```
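A hypothetical sketch of why such a flag matters for CUDA graph capture (this is an illustration of the pattern, not the PR's actual kernel logic): a common shortcut is to skip the scaling multiply when `grad_output` is exactly 1.0, but that check reads a device tensor on the host, which is forbidden during graph capture.

```python
import torch

def cross_entropy_backward(
    _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
) -> torch.Tensor:
    # torch.equal returns a Python bool, i.e. a host-side read of device
    # data. That synchronization breaks CUDA graph capture, so when
    # is_cg_capturable=True we skip the shortcut and always multiply,
    # which is graph-safe (and a no-op when grad_output == 1.0 anyway).
    if not is_cg_capturable and torch.equal(
        grad_output, torch.ones_like(grad_output)
    ):
        return _input  # gradient already correctly scaled
    return _input * grad_output
```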
This code conflicts with the upcoming IFU 2.8.
wenchenvincent left a comment
@sarthak-amd As mentioned in the previous PR, could you refactor the PR as 3 commits:
- 2 commits would be cherry-picks from the upstream PRs. (NVIDIA/TransformerEngine#1879, NVIDIA/TransformerEngine#2139)
- 1 commit for the ignore_idx with a test to cover it.
This way the PR would be very clear and easy to understand.
One of the aforementioned PRs is part of IFU 2.6, i.e. already part of ROCm TE; the other is part of IFU 2.8.
Kernel fixes:
- Divide by `denom = max(1, (target != ignore_idx).sum())` instead of `n_rows`
- Drop the `grad_output_stride` parameter; compute it dynamically: `1 if grad_output.numel() > 1 else 0`
- Add an `is_cg_capturable` flag for CUDA graph compatibility

Test improvements:
- Compare against `torch.nn.CrossEntropyLoss`
- Use `torch.square(loss)` for a non-trivial backward
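The denominator fix above can be sketched in plain PyTorch (the helper `masked_mean_ce` is mine, for illustration only): averaging over non-ignored tokens, clamped to at least 1 so an all-ignored batch yields 0 rather than NaN, matches `torch.nn.CrossEntropyLoss` with `ignore_index` set, whereas dividing by the total `n_rows` does not.

```python
import torch

def masked_mean_ce(
    logits: torch.Tensor, target: torch.Tensor, ignore_idx: int = -100
) -> torch.Tensor:
    # Per-token losses; ignored positions get a dummy class 0 and are
    # masked out below, so their value never reaches the result.
    per_token = torch.nn.functional.cross_entropy(
        logits, target.clamp(min=0), reduction="none"
    )
    mask = target != ignore_idx
    # denom = max(1, count of non-ignored tokens) -- NOT n_rows.
    denom = mask.sum().clamp(min=1)
    return (per_token * mask).sum() / denom
```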