
Conversation

@sudhu2k (Contributor) commented on Jan 15, 2026

Description

This PR enhances the GroupedLinear module with Triton kernel support for grouped GEMM operations. The implementation includes a complete Triton-based grouped matrix multiplication (GMM) backend that can be enabled via an environment variable, along with pre-tuned configurations for optimal performance.

  • Added support for using Triton kernels in GroupedLinear, enabled via the NVTE_USE_GROUPED_GEMM_TRITON environment variable (see the usage sketch after this list).
  • Updated the setup.py to include JSON configuration files for Triton kernels in the package data.
  • Added a new test case for grouped GEMM functionality in the CI pipeline.
  • Refactored the handling of input tensors and gradients to accommodate the new Triton kernel logic.
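
A minimal usage sketch of the new path. The environment variable is from this PR; the GroupedLinear constructor and forward arguments shown here follow the existing TE API as I understand it, so exact argument names may differ:

import os
import torch
import transformer_engine.pytorch as te

# Opt in to the Triton grouped GEMM backend added in this PR
# (the existing backend is used when the variable is unset).
os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"

# GroupedLinear fuses num_gemms independent GEMMs that share one stacked input.
layer = te.GroupedLinear(num_gemms=4, in_features=1024, out_features=1024)

m_splits = [128, 256, 64, 512]  # rows of the stacked input belonging to each group
inp = torch.randn(sum(m_splits), 1024, device="cuda")
out = layer(inp, m_splits)      # output shape: (sum(m_splits), 1024)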

Benchmark results:

https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3739558113
https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3746418683

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added Triton kernel support for GroupedLinear: Implemented a complete Triton-based grouped GEMM backend with dynamic backend selection based on the NVTE_USE_GROUPED_GEMM_TRITON environment variable (a reference for what grouped GEMM computes is sketched after this list)
  • Added an optional m_splits_tensor parameter to keep the split sizes on the GPU and avoid redundant CPU-GPU data transfers for improved performance
  • New GMM (Grouped Matrix Multiplication) module from AITER: Added comprehensive Triton kernel implementation in transformer_engine/pytorch/triton_kernels/gmm/ from AITER including:
    • gmm_common.py: Common utilities and helper functions
    • gmm_kernels.py: Core Triton kernel implementations for grouped GEMM operations
    • gmm_wrapper.py: High-level wrapper functions from AITER
    • pid_preprocessing.py: Process ID preprocessing for efficient kernel scheduling
  • Pre-tuned configurations: Added JSON configuration files for AMD GPU architectures:
    • gfx942-GMM.json: pre-tuned configs for gfx942 arch
    • gfx950-GMM.json: pre-tuned configs for gfx950 arch
  • Updated setup.py: Modified package data to include JSON configuration files for Triton kernels
  • Enhanced GroupedLinear module: Refactored grouped_linear.py to support the Triton kernel path with proper tensor handling
  • Added grouped_gemm.py wrapper: Created a high-level interface in TE for grouped GEMM operations
  • Extended common utilities: Added Triton kernel support flags in triton_kernels/common.py
  • New test suite: Added comprehensive test cases (from AITER) in tests/pytorch/triton_kernels/test_grouped_gemm.py (516 lines)
  • CI integration: Updated ci/pytorch.sh to include grouped GEMM tests in the CI pipeline
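
For reference on what the new backend computes: grouped GEMM over m_splits is conceptually equivalent to splitting the stacked input row-wise and multiplying each chunk by its own weight matrix, as the existing torch.split-based path does. A minimal PyTorch sketch of that semantics (illustration only, not the Triton implementation; the Triton kernel performs the same computation without the Python-level loop over groups):

import torch

def grouped_gemm_reference(inp, weights, m_splits):
    # inp: (sum(m_splits), K) stacked input; weights: list of (N, K) matrices, one per group.
    chunks = torch.split(inp, m_splits)                      # one row-chunk per group
    outs = [chunk @ w.t() for chunk, w in zip(chunks, weights)]
    return torch.cat(outs, dim=0)                            # (sum(m_splits), N)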

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sugovind added 2 commits January 15, 2026 16:23
…for JSON configs

@ipanfilo (Collaborator):

Update copyright date of modified files

run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
run_default_fa 1 triton_kernels/test_grouped_gemm.py

Collaborator:

Please move it two lines higher; alphabetical sort helps to find tests.
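
For clarity, the alphabetically sorted block would presumably read:

run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_grouped_gemm.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py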

delay_wgrad_compute,
parallel_mode=None,
):
os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"

Collaborator:

This env var won't be cleared if the test is skipped or fails.
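
One way to address this (a sketch; the fixture and test names below are placeholders, not code from this PR): pytest's monkeypatch fixture restores the environment at teardown even when the test is skipped or fails.

import pytest

@pytest.fixture
def grouped_gemm_triton_env(monkeypatch):
    # monkeypatch undoes the change after the test, including on skip or failure.
    monkeypatch.setenv("NVTE_USE_GROUPED_GEMM_TRITON", "1")

def test_grouped_linear_triton(grouped_gemm_triton_env):
    ...  # test body elided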

else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)

if not use_grouped_gemm_triton:

Collaborator:

make it elif

group_sizes_list=kwargs.get("m_splits_list", []),
)

grad_biases = [None] * len(m_splits) if bias is None else bias

Collaborator:

m_splits.shape[0] or len(m_splits_list)?
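
A sketch of one way to make the count robust to either representation (assuming m_splits may arrive as a Python list on the default path or as a device tensor on the Triton path; torch is already imported in grouped_linear.py):

num_gemms = m_splits.shape[0] if torch.is_tensor(m_splits) else len(m_splits)
grad_biases = [None] * num_gemms if bias is None else bias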

package_data = {"": ["VERSION.txt"]}
package_data = {
"": ["VERSION.txt"],
"transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],

Collaborator:

They should be part of the PyTorch extension installation, not TE core.

_ = general_grouped_gemm(
general_grouped_gemm_func = general_grouped_gemm_triton if use_grouped_gemm_triton else general_grouped_gemm
# Prepare m_splits for each backend
m_splits_for_kernel = m_splits

Collaborator:

It may be more straightforward to keep m_splits as-is and add a mandatory parameter m_splits_tensor (or m_splits_for_kernel) to general_grouped_gemm_triton(), instead of swapping them here.
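
A sketch of the suggested call sites (the positional arguments are placeholders, not the actual general_grouped_gemm signature):

if use_grouped_gemm_triton:
    out = general_grouped_gemm_triton(
        weights, inputmats, outputs,        # placeholder arguments
        m_splits=m_splits,                  # unchanged Python list of group sizes
        m_splits_tensor=m_splits_tensor,    # device tensor required only by the Triton path
    )
else:
    out = general_grouped_gemm(
        weights, inputmats, outputs,        # placeholder arguments
        m_splits=m_splits,
    )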
