Enhance GroupedLinear by integrating AITER Triton kernels #413
base: dev
Conversation
…for JSON configs
- Added support for using Triton kernels in GroupedLinear, allowing for optimized performance based on environment variables.
- Updated setup.py to include JSON configuration files for Triton kernels in the package data.
- Added a new test case for grouped GEMM functionality in the CI pipeline.
- Refactored the handling of input tensors and gradients to accommodate the new Triton kernel logic.
Update copyright date of modified files.
```shell
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
run_default_fa 1 triton_kernels/test_grouped_gemm.py
```
Please move it two lines higher; alphabetical sorting makes tests easier to find.
```python
    delay_wgrad_compute,
    parallel_mode=None,
):
    os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"
```
This environment variable won't be cleared if the test is skipped or fails.
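To address this, pytest's `monkeypatch.setenv` restores the variable automatically at test teardown. A plain-stdlib sketch of the same idea (the helper name `scoped_env` is hypothetical, not part of TE):

```python
import os
from contextlib import contextmanager

@contextmanager
def scoped_env(name, value):
    """Set an environment variable and restore its previous state on exit,
    even if the test body raises, fails, or is skipped mid-way."""
    saved = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if saved is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = saved

# Usage inside the test:
# with scoped_env("NVTE_USE_GROUPED_GEMM_TRITON", "1"):
#     ...  # run the grouped-GEMM test body
```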
```python
else:
    inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)

if not use_grouped_gemm_triton:
```
Make it an `elif`.
```python
        group_sizes_list=kwargs.get("m_splits_list", []),
    )

grad_biases = [None] * len(m_splits) if bias is None else bias
```
`m_splits.shape[0]` or `len(m_splits_list)`?
```diff
-package_data = {"": ["VERSION.txt"]}
+package_data = {
+    "": ["VERSION.txt"],
+    "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],
+}
```
They should be part of the PyTorch extension installation, not TE core.
```diff
-    _ = general_grouped_gemm(
+    general_grouped_gemm_func = general_grouped_gemm_triton if use_grouped_gemm_triton else general_grouped_gemm
+    # Prepare m_splits for each backend
+    m_splits_for_kernel = m_splits
```
It may be more straightforward to keep `m_splits` as-is and add a mandatory parameter `m_splits_tensor` or `m_splits_for_kernel` to `general_grouped_gemm_triton()`, instead of swapping them here.
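A minimal sketch of the suggested call shape, with placeholder backends standing in for the real GEMM implementations (the stub bodies and the `run_grouped_gemm` dispatcher name are hypothetical — only the parameter layout follows the comment above):

```python
def general_grouped_gemm(inputmats, weights, m_splits):
    # Placeholder for the existing path: m_splits is a Python list.
    return ("cublas", list(m_splits))

def general_grouped_gemm_triton(inputmats, weights, m_splits, m_splits_tensor):
    # Placeholder for the Triton path: m_splits stays a list, and the
    # device tensor is passed explicitly so no swap happens at the call site.
    return ("triton", list(m_splits), m_splits_tensor)

def run_grouped_gemm(inputmats, weights, m_splits, m_splits_tensor=None,
                     use_grouped_gemm_triton=False):
    # Dispatch without mutating m_splits; only the Triton backend
    # receives the extra tensor argument.
    if use_grouped_gemm_triton:
        return general_grouped_gemm_triton(
            inputmats, weights, m_splits, m_splits_tensor=m_splits_tensor)
    return general_grouped_gemm(inputmats, weights, m_splits)
```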
Description
This PR enhances the GroupedLinear module with Triton kernel support for grouped GEMM operations. The implementation includes a complete Triton-based grouped matrix multiplication (GMM) backend that can be enabled via an environment variable, along with pre-tuned configurations for optimal performance.
Benchmark results:
https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3739558113
https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3746418683
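For reference, the grouped GEMM semantics that both backends implement can be sketched in NumPy: row-chunk `i` of the input (of `m_splits[i]` rows) is multiplied by the group's weight matrix, and the results are concatenated. The function name is illustrative, not part of the PR:

```python
import numpy as np

def grouped_gemm_reference(inp, weights, m_splits):
    """Reference grouped GEMM: split inp's rows by m_splits, multiply each
    chunk by the matching weight matrix, and concatenate the outputs."""
    assert sum(m_splits) == inp.shape[0]
    outs, start = [], 0
    for w, m in zip(weights, m_splits):
        outs.append(inp[start:start + m] @ w)
        start += m
    return np.concatenate(outs, axis=0)
```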
Type of change
Changes
- Environment variable toggle (`NVTE_USE_GROUPED_GEMM_TRITON`)
- `m_splits_tensor` parameter to keep tensor data on GPU and avoid redundant CPU-GPU data transfers for improved performance
- `transformer_engine/pytorch/triton_kernels/gmm/` from AITER, including:
  - `gmm_common.py`: common utilities and helper functions
  - `gmm_kernels.py`: core Triton kernel implementations for grouped GEMM operations
  - `gmm_wrapper.py`: high-level wrapper functions from AITER
  - `pid_preprocessing.py`: process ID preprocessing for efficient kernel scheduling
  - `gfx942-GMM.json`: pre-tuned configs for gfx942 arch
  - `gfx950-GMM.json`: pre-tuned configs for gfx950 arch
- Updated `grouped_linear.py` to support the Triton kernel path with proper tensor handling
- `triton_kernels/common.py`
- `tests/pytorch/triton_kernels/test_grouped_gemm.py` (516 lines)
- Updated `ci/pytorch.sh` to include grouped GEMM tests in the CI pipeline

Checklist: