Skip to content

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Dec 10, 2025

Description

Adds nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes. A GPU kernel (setup_grouped_gemm_kernel) converts NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLAS requirements (pointer arrays + per-matrix M/N/K).

New API

void nvte_grouped_gemm(int transa, int transb, 
                       const NVTETensor alpha, 
                       const NVTEGroupedTensor A,
                       const NVTEGroupedTensor B, 
                       const NVTETensor beta, 
                       const NVTEGroupedTensor C,
                       NVTEGroupedTensor D, 
                       NVTETensor workspace_setup, 
                       NVTETensor workspace_cublas,
                       NVTEMatmulConfig config, 
                       cudaStream_t stream, 
                       const int64_t *avg_m,
                       const int64_t *avg_n, 
                       const int64_t *avg_k);

Computes D = alpha * op(A) @ op(B) + beta * C for groups of matrices with potentially different shapes.

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • GPU setup kernel computing pointers/dims from grouped tensor metadata
  • FP8 support with scale_inv handling and TN layout selection on Hopper
  • GroupedGemmSetupWorkspace struct for cuBLAS workspace layout
  • Tests in test_grouped_gemm.cu comparing against nvte_multi_tensor_gemm (FP8/BF16, various shapes and transpose layouts)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL changed the title [common] Add support for cublasLt GEMM for GroupedTensor [common] Add support for cuBLASLt GEMM for GroupedTensor Dec 10, 2025
pre-commit-ci bot and others added 3 commits December 10, 2025 14:32
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@ptrendx ptrendx added the MoE label Dec 10, 2025
@ptrendx ptrendx linked an issue Dec 10, 2025 that may be closed by this pull request
pggPL and others added 2 commits December 10, 2025 22:34
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 10, 2025

/te-ci L0

@pggPL pggPL marked this pull request as ready for review December 10, 2025 21:43
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 10, 2025

Greptile Summary

  • Adds nvte_grouped_gemm API for batched matrix multiplication on tensors with varying shapes using cuBLASLt grouped matmul, enabling efficient computation for transformers with variable sequence lengths
  • Implements GPU kernel setup_grouped_gemm_kernel to convert between NVTEGroupedTensor format (contiguous buffer + offsets) and cuBLAS requirements (pointer arrays + per-matrix M/N/K dimensions)
  • Includes comprehensive FP8 support with scale_inv handling, TN layout selection for Hopper GPUs, and requires cuBLAS 13.1+ with Blackwell architecture (SM100) or newer

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu New file implementing the grouped GEMM functionality with complex operand selection logic, GPU kernel for format conversion, and comprehensive input validation
tests/cpp/operator/test_grouped_gemm.cu New comprehensive test file validating grouped GEMM against existing multi-tensor GEMM for FP8/BF16 data types and various matrix configurations

Confidence score: 4/5

  • This PR is generally safe to merge with moderate review attention needed for the complex GEMM implementation
  • Score reflects well-structured code with comprehensive testing but complexity in GPU kernel logic and cuBLAS integration requires careful validation
  • Pay close attention to cublaslt_grouped_gemm.cu for the GPU kernel implementation and operand selection logic which handles critical format conversions and FP8 scaling

Sequence Diagram

sequenceDiagram
    participant User
    participant API as nvte_grouped_gemm
    participant Validator as validate_grouped_gemm_inputs
    participant Selector as select_grouped_operand
    participant Kernel as setup_grouped_gemm_kernel
    participant cuBLAS as cublasLtMatmul
    participant HandleMgr as cublasHandleManager

    User->>API: "Call nvte_grouped_gemm(transa, transb, alpha, A, B, beta, C, D, workspace_setup, workspace_cublas, stream)"
    
    API->>Validator: "validate_grouped_gemm_inputs(A, B, C, D, alpha, beta)"
    Validator-->>API: "Validation complete"
    
    API->>Selector: "select_grouped_operand(A, transa, is_A=true)"
    Selector-->>API: "A_sel (data pointer, dtype, transpose flag)"
    
    API->>Selector: "select_grouped_operand(B, transb, is_A=false)"
    Selector-->>API: "B_sel (data pointer, dtype, transpose flag)"
    
    API->>API: "GroupedGemmSetupWorkspace::from_buffers(workspace_ptr, num_tensors)"
    
    API->>Kernel: "launch_grouped_gemm_setup(workspace, A_sel, B_sel, C, D, alpha, beta, num_tensors, stream)"
    Note over Kernel: "Populates pointer arrays and M/N/K dimensions for each matrix in group"
    Kernel-->>API: "Setup arrays populated"
    
    API->>HandleMgr: "GetHandle()"
    HandleMgr-->>API: "cublasLtHandle_t"
    
    API->>API: "init_matrix_layouts(descA, descB, descC, descD, workspace, A_sel, B_sel, D, num_tensors)"
    API->>API: "init_matmul_desc(matmulDesc, op_A, op_B)"
    API->>API: "set_fp8_scale_pointers(matmulDesc, A_sel, B_sel)"
    API->>API: "select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC, descD, avg_m, avg_n, avg_k)"
    
    API->>cuBLAS: "cublasLtMatmul(handle, matmulDesc, alpha_ptrs, A_ptrs, B_ptrs, beta_ptrs, C_ptrs, D_ptrs, algo, workspace, stream)"
    Note over cuBLAS: "D = alpha * op(A) @ op(B) + beta * C for each matrix group"
    cuBLAS-->>API: "Grouped GEMM complete"
    
    API-->>User: "Return"
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (4)

  1. tests/cpp/operator/test_grouped_gemm.cu, line 75 (link)

    logic: missing columnwise_data in move assignment

  2. tests/cpp/operator/test_grouped_gemm.cu, line 336-337 (link)

    logic: kSameFirst and kSameLast test cases use identical shape tuples. Should kSameFirst and kSameLast have different shape patterns to properly test the respective scenarios?

  3. tests/cpp/operator/test_grouped_gemm.cu, line 367-378 (link)

    logic: missing case for InputCase::kFP8Delayed

  4. transformer_engine/common/include/transformer_engine/gemm.h, line 265-266 (link)

    style: The avg_m, avg_n, avg_k parameters are not documented in the function comment

    What do these average dimension parameters represent and how should they be computed?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

@ptrendx ptrendx requested a review from timmoon10 December 10, 2025 22:35
pggPL and others added 2 commits December 11, 2025 11:56
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 266 (link)

    syntax: Documentation incorrectly states alpha[i] and beta[i]. The implementation uses a single alpha and beta value for all matrices in the group (batch stride is 1 in init_matmul_desc at cublaslt_gemm.cu:1404), not per-matrix scaling.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

pggPL and others added 3 commits December 11, 2025 12:16
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 11, 2025

/te-ci

greptile-apps[bot]

This comment was marked as resolved.

Comment on lines 273 to 274
NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m,
const int64_t *avg_n, const int64_t *avg_k);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The average sizes seem like advanced configs that would be better to leave out of the top-level API. Can we move them inside NVTEMatmulConfig?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was suggestion of @ptrendx - customer may potentially want to use them if they know something more about the shapes. For example if there are multiple tensors of with k dimension D and one tensor with K dimension equal to 1, then it is potentially true that telling cublas that avg dim = D will result in better performance.

Copy link
Collaborator

@timmoon10 timmoon10 Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should definitely expose these options, but it would be better to put them in NVTEMatmulConfig rather than the top-level function signature. If you look at nvte_cublas_gemm:

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream);

You see advanced options like pre_gelu_out and use_split_accumulator that are only needed for specialized cases. It's even worse for use_split_accumulator, since that's only relevant for FP8 on Hopper and now it's sitting uselessly in the API forever. Compare with the v2 API, which is much closer to the original GEMM API:
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream);

The idea of NVTEMatmulConfig is to hold these advanced, kernel-specific options in a way where we can add or deprecate them easily without breaking API changes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (2)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 256 (link)

    syntax: Documentation mentions non-existent config parameter - this parameter is not in the function signature on line 276-280

  2. tests/cpp/operator/test_grouped_gemm.cu, line 141 (link)

    logic: Alignment calculation appears incorrect - multiplying 16 bytes by 8 bits gives 128 bits, not the intended 16-byte alignment

6 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 19, 2025

/te-ci L0

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (4)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 260-279 (link)

    syntax: Missing config parameter in implementation

    The function signature includes an NVTEMatmulConfig config parameter in the declaration (line 276), but the actual implementation in cublaslt_grouped_gemm.cu:498 doesn't accept this parameter. The call site in tests (line 449-462 of test_grouped_gemm.cu) also omits it, passing 0 instead.

    Either add the parameter to the implementation or remove it from the declaration.

  2. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 592 (link)

    logic: cuBLAS version mismatch in error message

    Error message mentions "cuBLAS 13.2+" but the compile-time guard checks for CUBLAS_VERSION >= 130100 (line 29), which is cuBLAS 13.1. Also, header comment references cuBLAS 13.1 vs 13.2 inconsistently.

  3. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 472-473 (link)

    logic: Potential null pointer dereference when C is NULL

    When C is NULL and use_null_c=true, C->data.dptr and C->dtype() will cause a segfault. The code sets inputC = outputD when C is NULL (line 525), but this happens after launch_grouped_gemm_setup is called, where C is still NULL.

  4. tests/cpp/operator/test_grouped_gemm.cu, line 95-102 (link)

    style: Workspace size calculation doesn't match implementation

    Test calculates 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes (6 pointer arrays total), but the implementation in GroupedGemmSetupWorkspace::from_buffers expects exactly 6 pointer arrays + 3 int arrays. The calculation is correct but the comment formatting makes it unclear. Consider: 6 * ptr_bytes + 3 * int_bytes.

7 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (2)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 237-268 (link)

    syntax: Version inconsistency: note section mentions cuBLAS13.1+ but requirements section specifies13.2+. Need to align these versions consistently.

    Which cuBLAS version is actually required - 13.1 or 13.2?

  2. tests/cpp/operator/test_grouped_gemm.cu, line 341-344 (link)

    logic: Transpose logic appears inverted - for transa=true, A should be transposed so input shape should be (K,M) to produce effective (M,K) for GEMM. Is the tensor shape logic correct for transpose operations? Typically transa=true means the input A matrix needs to be transposed during the operation.

7 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. tests/cpp/operator/test_grouped_gemm.cu, line 485 (link)

    syntax: incorrect version check - should be 130100 not 130200

    The API requires cuBLAS 13.1+ (version 130100), but this conditional check uses 130200. This mismatch means tests will be skipped even on cuBLAS 13.1.

8 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 22, 2025

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@pggPL
Copy link
Collaborator Author

pggPL commented Dec 22, 2025

/te-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GroupedGemm: FP8 per-tensor via cuBLAS

3 participants