diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 66400ffd7..f12a95d79 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -53,7 +53,11 @@ jobs: || github.actor == 'lhb8125' || github.actor == 'kunlunl' || github.actor == 'pstjohn' - || github.actor == 'mk-61' + || github.actor == 'vcherepanov-nv' + || github.actor == 'tdophung' + || github.actor == 'vthumbe1503' + || github.actor == 'janekb04' + || github.actor == 'shengfangd' ) steps: - name: Check if comment is issued by authorized person diff --git a/.gitmodules b/.gitmodules index 45691e71a..2fcd51fcb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -20,3 +20,6 @@ [submodule "examples/pytorch/nanogpt"] path = examples/pytorch/nanogpt url = https://github.com/floraamd/nanoGPTwTE.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 91b7532f3..deda80e53 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 91b7532f3386768bba4f444ee7672b497f34da8a +Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000..57e3cfb47 --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 57e3cfb47a2d9e0d46eb6335c3dc411498efa198 diff --git a/README.rst b/README.rst index daf3cbc45..66d1b0b3e 100644 --- a/README.rst +++ b/README.rst @@ -526,15 +526,15 @@ For example to use the NGC PyTorch container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.04-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3 For example to use the NGC JAX container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.04-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.08-py3 -Where 25.04 (corresponding to April 2025 release) is the container version. +Where 25.08 (corresponding to August 2025 release) is the container version. **Benefits of using NGC containers:** diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index dafafdff4..1df16cc01 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention pd.set_option("display.precision", 4) @@ -197,7 +197,7 @@ def main(): ) for model in model_configs.keys(): config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/benchmarks/attention/benchmark_attention_rocm.py b/benchmarks/attention/benchmark_attention_rocm.py index b126fb022..0c37696ac 100644 --- a/benchmarks/attention/benchmark_attention_rocm.py +++ b/benchmarks/attention/benchmark_attention_rocm.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -13,17 +13,17 @@ import transformer_engine from transformer_engine_torch import NVTE_Fused_Attn_Backend -# Add test_fused_attn to the sys path +# Add TE repo root to the sys path tests_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn") + os.path.join(os.path.dirname(__file__), "../../") ) sys.path.append(tests_path) -from test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention pd.set_option("display.precision", 4) @@ -46,12 +46,12 @@ is_training = True model_configs = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq - "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask - "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias - "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA - "test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias") + # b, sq, h, dqk + "test_0": ModelConfig(2, 512, 16, 64), # short seq + "test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # longer seq, mask + "test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"), # bias + "test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA + "test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=8, attn_mask_type="causal_bottom_right") } # DataFrame indices and columns for results @@ -303,7 +303,7 @@ def sanity_checks( } for model, cfg in model_configs.items(): - avail, _, fused_bes = _get_attention_backends( + avail, _, fused_bes = get_available_attention_backends( cfg, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -364,7 +364,7 @@ def main(args): # Benchmarking starts.. 
for model in model_configs.keys(): config = model_configs[model] - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 0dbee212d..44f1c8967 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): num_gemms_list = [8] if args.profile: - mkns = [(4096, 4096, 4096)] + mkns = [(4096 * 8, 4096, 4096)] # in profile mode, only run one recipe specified in args.recipe assert args.recipe != "all", ( "In profile mode, only one recipe can be specified, please specify the recipe as" diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 2a45a8a5c..81006d78c 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.6.0.dev0 +2.8.0.dev0 diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 4056a5fa7..bb084293f 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -27,20 +27,7 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - reqs = ["einops"] - if not rocm_build(): - reqs.append( - "nvdlfw-inspect @" - " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" - ) - reqs.extend( - [ - "torch>=2.1", - "onnx", - "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871", - ] - ) - return reqs + return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"] def test_requirements() -> List[str]: diff --git a/build_tools/utils.py b/build_tools/utils.py index 739e353e4..e3c5b6be8 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -16,7 +16,7 @@ import subprocess import sys from pathlib import Path -from importlib.metadata import version +from importlib.metadata import version as get_version from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union @@ -340,7 +340,7 @@ def cuda_version() -> Tuple[int, ...]: return tuple(int(v) for v in version) try: - version_str = version("nvidia-cuda-runtime-cu12") + version_str = get_version("nvidia-cuda-runtime-cu12") version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit()) return version_tuple except importlib.metadata.PackageNotFoundError: diff --git a/ci/jax.sh b/ci/jax.sh index 6229a7de3..b72a72e52 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -1,5 +1,5 @@ #!/bin/sh -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. 
@@ -54,6 +54,7 @@ run_default_fa_lbl() { run_test_config() { echo ==== Run with Fused attention backend: $_fus_attn ==== + export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests run_default_fa 1 test_custom_call_compute.py run_default_fa 1 test_functions.py run 1 test_fused_attn.py @@ -61,7 +62,6 @@ run_test_config() { run_default_fa 1 test_helper.py run_default_fa 1 test_layer.py #it effectevly always uses unfused attention run_default_fa 1 test_sanity_import.py - run_default_fa 1 test_sharding.py run_default_fa 1 test_softmax.py } @@ -76,8 +76,10 @@ run_test_config_mgpu() { if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then _dfa_level=2 + export NVTE_JAX_UNITTEST_LEVEL=L1 else _dfa_level=3 + export NVTE_JAX_UNITTEST_LEVEL=L2 fi run $_dfa_level test_distributed_fused_attn.py $_timeout_args run_default_fa 3 test_distributed_layernorm.py diff --git a/ci/pytorch.sh b/ci/pytorch.sh index a6ad620fe..be150485f 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -1,5 +1,5 @@ #!/bin/sh -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. @@ -65,7 +65,9 @@ run_test_config(){ run_default_fa 1 test_recipe.py run 1 test_sanity.py run_default_fa 1 test_sanity_import.py - run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test + run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test + run_default_fa 1 attention/test_cp_utils.py + run_default_fa 1 attention/test_kv_cache.py run_default_fa 1 triton_kernels/test_cast.py run_default_fa 1 triton_kernels/test_cast_mxfp8.py run_default_fa 1 triton_kernels/test_norm_common.py @@ -73,15 +75,13 @@ run_test_config(){ NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py run_default_fa 1 test_parallel_cross_entropy.py NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py - NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py + NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py - NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_numerics.py - NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_fusible_ops.py - NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py - NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_fusible_ops.py - NVTE_USE_ATOMIC_AMAX=0 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py - NVTE_USE_ATOMIC_AMAX=0 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_fusible_ops.py - NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 triton_kernels/test_cast.py + NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 test_numerics.py + NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 test_fusible_ops.py + NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_numerics.py + NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_fusible_ops.py + NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 triton_kernels/test_cast.py } run_test_config_mgpu(){ @@ -93,8 +93,8 @@ run_test_config_mgpu(){ run_default_fa 2 distributed/test_numerics.py run_default_fa 1 
distributed/test_torch_fsdp2.py run_default_fa 2 distributed/test_torch_fsdp2_fp8.py - run_default_fa_lbl "flash" 3 fused_attn/test_fused_attn_with_cp.py -k "with_flash" - run_default_fa_lbl "fused" 2 fused_attn/test_fused_attn_with_cp.py -k "with_fused" + run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash" + run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused" } run_benchmark() { diff --git a/docs/api/jax.rst b/docs/api/jax.rst index d72af37ec..1af5cd1d0 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -19,6 +19,10 @@ Variables are available in `transformer_engine.jax.sharding`. * JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded. +Checkpointing +------------------------------------ +When using checkpointing with Transformer Engine JAX, please be aware of the checkpointing policy being applied to your model. Any JAX checkpointing policy using `dot`, such as `jax.checkpoint_policies.dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine as they do not always use the `jax.lax.dot_general` primitive. Instead, you can use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or similar policies that are designed to work with Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policies that do not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`. + Modules ------------------------------------ .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 0b1f1fab9..fcfa20cbd 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -49,7 +49,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute -.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs +.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs .. autoapifunction:: transformer_engine.pytorch.moe_unpermute @@ -63,3 +63,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.destroy_ub .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index + +.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode + :members: FP8, NONE \ No newline at end of file diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index bc2b95057..a7b86dad3 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea There are 4 things one needs to do to use Transformer Engine debug features: 1. Create a configuration YAML file to configure the desired features. -2. Import, and initialize the `Nvidia-DL-Framework-Inspect `_ tool, which is installed as the dependency of the Transformer Engine. +2. Import, initialize, and install the `Nvidia-DL-Framework-Inspect `_ tool. 3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically. 4. Invoke ``debug_api.step()`` at the end of one forward-backward pass. @@ -141,7 +141,7 @@ Adjusting Python file In the modified code above, the following changes were made: 1. Added an import for ``nvdlfw_inspect.api``. -2. 
Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. +2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. The directory with Transformer Engine features is located `here `_. The full parameters description could be found :doc:`here <3_api_debug_setup>`. 3. Added ``debug_api.step()`` after each of the forward-backward pass. Inspecting the logs @@ -238,4 +238,4 @@ Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_ .. figure:: ./img/tensorboard.png :align: center - Fig 2: TensorBoard with plotted stats. \ No newline at end of file + Fig 2: TensorBoard with plotted stats. diff --git a/docs/debug/3_api_te_calls.rst b/docs/debug/3_api_te_calls.rst index eb66c8ff2..1435d41d7 100644 --- a/docs/debug/3_api_te_calls.rst +++ b/docs/debug/3_api_te_calls.rst @@ -12,14 +12,7 @@ Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine w Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 1 Linear Layer. For tensors mentioned in ``config.yaml``, behavior of ``modify_tensor_enabled()`` and ``modify_tensor()`` calls are substituted with definitions from the feature class. Other calls return default values - in fact they do nothing. -In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. The order of these calls is illustrated in the image below. - -.. figure:: ./img/api_calls2.svg - :align: center - - Fig 2: The calls to Nvidia-DL-Framework-Inspect done for Transformer Engine. There are 2 types of calls: GEMM calls and routing calls. - - +In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. There are 2 categories of API calls, each is used for different purposes: - GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them, @@ -32,14 +25,15 @@ if fusions happen. An important remark is that if no feature is used for the lay .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor -.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor - -.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize + .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled .. 
autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled diff --git a/docs/debug/img/api_calls2.svg b/docs/debug/img/api_calls2.svg deleted file mode 100644 index 5df72fc2e..000000000 --- a/docs/debug/img/api_calls2.svg +++ /dev/null @@ -1 +0,0 @@ -Tensor Ainspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeGEMMinspect_tensormodify_tensorinspect_tensor_enabledinspect_tensor_postquantize_enabledfp8_gemm_enabledmodify_tensor_enabledTensor Binspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeRouting callsGEMM calls \ No newline at end of file diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index e9eec14d9..97f1bcd7e 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -5,7 +5,7 @@ import os import torch from typing import Tuple -from tests.pytorch.fused_attn.test_fused_attn import ModelConfig +from tests.pytorch.utils import ModelConfig from transformer_engine.pytorch.attention import DotProductAttention # Initialize RNG state diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 53a5eede7..61a6ad949 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -375,7 +375,7 @@ "\n", "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n", "\n", - "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." + "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." ] }, { @@ -390,14 +390,14 @@ "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", - "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n", + "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n", "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. 
For example,\n", - "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" + "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)" ] }, { @@ -458,7 +458,7 @@ " \n", "\n", "\n", - "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", + "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "\n", "
\n", "Note\n", @@ -548,7 +548,7 @@ "id": "dda4a589", "metadata": {}, "source": [ - "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", + "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n", "\n", "### 3.3 Attention Bias\n", "\n", @@ -594,7 +594,7 @@ "\n", "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n", "\n", - "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)." + "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)." ] }, { @@ -612,7 +612,7 @@ "\n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "\n", - "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." + "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." 
] } ], diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 2c32e8b5f..cf650265b 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention # data type dtype = torch.bfloat16 @@ -90,7 +90,7 @@ def main(): models = ["test_0"] for model in models: config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/docs/examples/onnx/onnx_export.ipynb b/docs/examples/onnx/onnx_export.ipynb new file mode 100644 index 000000000..26ac71188 --- /dev/null +++ b/docs/examples/onnx/onnx_export.ipynb @@ -0,0 +1,345 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Export to ONNX and inference using TensorRT\n", + "\n", + "
\n", + "\n", + "Note:\n", + "\n", + "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n", + "\n", + "
\n", + "\n", + "Transformer Engine (TE) is a library designed primarily for training DL models in low precision. It is not specifically optimized for inference tasks, so other dedicated solutions should be used. NVIDIA provides several [inference tools](https://www.nvidia.com/en-us/solutions/ai/inference/) that enhance the entire inference pipeline. Two prominent NVIDIA inference SDKs are [TensorRT](https://github.com/NVIDIA/TensorRT) and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).\n", + "\n", + "This tutorial illustrates how one can export a PyTorch model to ONNX format and subsequently perform inference with TensorRT. This approach is particularly beneficial if model integrates Transformer Engine layers within more complex architectures. It's important to highlight that for Transformer-based large language models (LLMs), TensorRT-LLM could provide a more optimized inference experience. However, the ONNX-to-TensorRT approach described here may be more suitable for other models, such as diffusion-based architectures or vision transformers.\n", + "\n", + "#### Creating models with TE\n", + "\n", + "Let's begin by defining a simple model composed of layers both from Transformer Engine and standard PyTorch:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "import torch.nn as nn\n", + "import transformer_engine as te\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "# batch size, sequence length, hidden dimension\n", + "B, S, H = 256, 512, 256\n", + "\n", + "class Model(torch.nn.Module):\n", + " def __init__(self, hidden_dim=H, num_non_te_layers=16, num_te_layers=4, num_te_heads=4):\n", + " super(Model, self).__init__()\n", + " self.non_te_part = nn.Sequential(\n", + " *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_non_te_layers)]\n", + " )\n", + " self.te_part = nn.Sequential(\n", + " *[te.pytorch.TransformerLayer(hidden_dim, hidden_dim, num_te_heads) for _ in range(num_te_layers)]\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.non_te_part(x)\n", + " return self.te_part(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run some simple inference benchmarks:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average inference time FP32: 0.065 ms\n", + "Average inference time FP8: 0.062 ms\n" + ] + } + ], + "source": [ + "from utils import _measure_time\n", + "\n", + "model = Model().eval().cuda()\n", + "inps = (torch.randn([S, B, H], device=\"cuda\"),)\n", + "def _inference(fp8_enabled):\n", + " with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):\n", + " model(*inps)\n", + "\n", + "te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n", + "te_fp8_time = _measure_time(lambda: _inference(fp8_enabled=True))\n", + "\n", + "print(f\"Average inference time FP32: {te_fp32_time} ms\")\n", + "print(f\"Average inference time FP8: {te_fp8_time} ms\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Exporting the TE Model to ONNX Format\n", + "\n", + "PyTorch developed a new [ONNX exporter](https://pytorch.org/docs/stable/onnx.html) built on TorchDynamo and plans to phase out the existing TorchScript exporter. 
As this feature is currently in active development, we recommend running this process with the latest PyTorch version.\n", + "\n", + "\n", + "To export a Transformer Engine model into ONNX format, follow these steps:\n", + "\n", + "- Conduct warm-up run within autocast using the recipe intended for export.\n", + "- Encapsulate your export-related code within `te.onnx_export`, ensuring warm-up runs remain outside this wrapper.\n", + "- Use the PyTorch Dynamo ONNX exporter by invoking: `torch.onnx.export(..., dynamo=True)`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exporting model_fp8.onnx\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n", + "[torch.onnx] Run decomposition...\n", + "[torch.onnx] Run decomposition... ✅\n", + "[torch.onnx] Translate the graph into ONNX...\n", + "[torch.onnx] Translate the graph into ONNX... ✅\n", + "Applied 12 of general pattern rewrite rules.\n", + "Exporting model_fp32.onnx\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n", + "[torch.onnx] Run decomposition...\n", + "[torch.onnx] Run decomposition... ✅\n", + "[torch.onnx] Translate the graph into ONNX...\n", + "[torch.onnx] Translate the graph into ONNX... ✅\n", + "Applied 12 of general pattern rewrite rules.\n" + ] + } + ], + "source": [ + "from transformer_engine.pytorch.export import te_translation_table\n", + "\n", + "def export(model, fname, inputs, fp8=True):\n", + " with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8):\n", + " # ! IMPORTANT !\n", + " # Transformer Engine models must have warm-up run\n", + " # before export. FP8 recipe during warm-up should \n", + " # match the recipe used during export.\n", + " model(*inputs)\n", + " \n", + " # Only dynamo=True mode is supported;\n", + " # dynamo=False is deprecated and unsupported.\n", + " #\n", + " # te_translation_table contains necessary ONNX translations\n", + " # for FP8 quantize/dequantize operators.\n", + " print(f\"Exporting {fname}\")\n", + " with te.pytorch.onnx_export(enabled=True):\n", + " torch.onnx.export(\n", + " model,\n", + " inputs,\n", + " fname,\n", + " output_names=[\"output\"],\n", + " dynamo=True,\n", + " custom_translation_table=te_translation_table\n", + " )\n", + "\n", + "# Example usage:\n", + "export(model, \"model_fp8.onnx\", inps, fp8=True)\n", + "export(model, \"model_fp32.onnx\", inps, fp8=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inference with TensorRT\n", + "\n", + "TensorRT is a high-performance deep learning inference optimizer and runtime developed by NVIDIA. It enables optimized deployment of neural network models by maximizing inference throughput and reducing latency on NVIDIA GPUs. TensorRT performs various optimization techniques, including layer fusion, precision calibration, kernel tuning, and memory optimization. \n", + "For detailed information and documentation, refer to the official [TensorRT documentation](https://developer.nvidia.com/tensorrt).\n", + "\n", + "When using TensorRT, ONNX model must first be compiled into a TensorRT engine. 
This compilation step involves converting the ONNX model into an optimized representation tailored specifically to the target GPU platform. The compiled engine file can then be loaded into applications for rapid and efficient inference execution." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "!trtexec --onnx=model_fp32.onnx --saveEngine=model_fp32.engine > output_fp32.log 2>&1\n", + "!trtexec --onnx=model_fp8.onnx --saveEngine=model_fp8.engine > output_fp8.log 2>&1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run the benchmarks for inference:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average inference time without TRT (FP32 for all layers): 0.065 ms\n", + "Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): 0.062 ms, speedup = 1.05x\n", + "Average inference time with TRT (FP32 for all layers): 0.0500 ms, speedup = 1.30x\n", + "Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): 0.0470 ms, speedup = 1.38x\n" + ] + } + ], + "source": [ + "import tensorrt as trt\n", + "\n", + "# Output tensor is allocated - TRT needs static memory address.\n", + "output_tensor = torch.empty_like(model(*inps))\n", + "\n", + "# Loads TRT engine from file.\n", + "def load_engine(engine_file_path):\n", + " logger = trt.Logger(trt.Logger.WARNING)\n", + " runtime = trt.Runtime(logger)\n", + " \n", + " with open(engine_file_path, \"rb\") as f:\n", + " engine_data = f.read()\n", + " \n", + " engine = runtime.deserialize_cuda_engine(engine_data)\n", + " return engine\n", + "\n", + "def benchmark_inference(model_name):\n", + " engine = load_engine(model_name)\n", + " context = engine.create_execution_context()\n", + " stream = torch.cuda.Stream()\n", + " \n", + " # TRT need static input and output addresses.\n", + " # Here they are set.\n", + " for i in range(len(inps)):\n", + " context.set_tensor_address(engine.get_tensor_name(i), inps[i].data_ptr()) \n", + " context.set_tensor_address(\"output\", output_tensor.data_ptr())\n", + " \n", + " def _inference():\n", + " # The data is loaded from static input addresses\n", + " # and output is written to static output address.\n", + " context.execute_async_v3(stream_handle=stream.cuda_stream)\n", + " stream.synchronize()\n", + " \n", + " return _measure_time(_inference)\n", + "\n", + "\n", + "trt_fp8_time = benchmark_inference(\"model_fp8.engine\")\n", + "trt_fp32_time = benchmark_inference(\"model_fp32.engine\")\n", + "\n", + "print(f\"Average inference time without TRT (FP32 for all layers): {te_fp32_time} ms\")\n", + "print(f\"Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): {te_fp8_time} ms, speedup = {te_fp32_time/te_fp8_time:.2f}x\")\n", + "print(f\"Average inference time with TRT (FP32 for all layers): {trt_fp32_time:.4f} ms, speedup = {te_fp32_time/trt_fp32_time:.2f}x\")\n", + "print(f\"Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): {trt_fp8_time:.4f} ms, speedup = {te_fp32_time/trt_fp8_time:.2f}x\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

\n", + "\n", + "\n", + "| Run | Inference Time (ms) | Speedup |\n", + "| ----------------------------------| ------------------- | ------------------- |\n", + "| PyTorch + TE | 0.065 | 1.00x |\n", + "| PyTorch + TE (FP8 for TE layers) | 0.062 | 1.05x |\n", + "| TRT | 0.0500 | 1.30x |\n", + "| TRT (FP8 for TE layers) | 0.047 | 1.38x |\n", + "\n", + "Note that this example highlights how TensorRT can speed up models composed of both TE and non-TE layers.\n", + "If a larger part of the model's layers were implemented with TE, the benefits of using FP8 for inference could be greater.\n", + "\n", + "

\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We clearly observe performance improvements when using FP8 and the TensorRT inference engine. These improvements may become even more significant with more complex models, as TensorRT could potentially identify additional optimization opportunities.\n", + "\n", + "#### Appendix: Low Precision Operators in ONNX and TensorRT\n", + "\n", + "The ONNX standard does not currently support all precision types provided by the Transformer Engine. All available ONNX operators are listed on [this website](https://onnx.ai/onnx/operators/). Consequently, TensorRT and the Transformer Engine utilize certain specialized low-precision operators, detailed below.\n", + "\n", + "**TRT_FP8_QUANTIZE**\n", + "\n", + "- **Name**: TRT_FP8_QUANTIZE\n", + "- **Domain**: trt\n", + "- **Inputs**:\n", + " - `x`: float32 tensor\n", + " - `scale`: float32 scalar\n", + "- **Outputs**:\n", + " - `y`: int8 tensor\n", + "\n", + "Produces an int8 tensor that represents the binary encoding of FP8 values.\n", + "\n", + "**TRT_FP8_DEQUANTIZE**\n", + "\n", + "- **Name**: TRT_FP8_DEQUANTIZE\n", + "- **Domain**: trt\n", + "- **Inputs**:\n", + " - `x`: int8 tensor\n", + " - `scale`: float32 scalar\n", + "- **Outputs**:\n", + " - `y`: float32 tensor\n", + "\n", + "Converts FP8-encoded int8 tensor data back into float32 precision.\n", + "\n", + "
\n", + "\n", + "Note:\n", + "\n", + "Since standard ONNX operators do not support certain input and output precision types, a workaround is employed: tensors are dequantized to higher precision (float32) before input into these operators or quantized to lower precision after processing. TensorRT recognizes such quantize-dequantize patterns and replaces them with optimized operations. More details are available in [this section](https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-quantized-types.html#tensorrt-processing-of-q-dq-networks) of the TensorRT documentation.\n", + "\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/examples/onnx/utils.py b/docs/examples/onnx/utils.py new file mode 100644 index 000000000..7acf2ffc6 --- /dev/null +++ b/docs/examples/onnx/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Utility functions for ONNX export. +""" + +import time +import torch + + +def _measure_time(f): + + time_taken = [] + num_iterations = 10 + f() # warm-up + + for _ in range(num_iterations): + start_time = time.time() + f() + torch.cuda.synchronize() + end_time = time.time() + time_taken.append(end_time - start_time) + return round(sum(time_taken) / num_iterations, 3) diff --git a/docs/examples/te_gemma/media/calibration.svg b/docs/examples/te_gemma/media/calibration.svg new file mode 100644 index 000000000..16e1a4314 --- /dev/null +++ b/docs/examples/te_gemma/media/calibration.svg @@ -0,0 +1,620 @@ + + + + + + + + + + + FP8 with initial scaling factors + + + High + precision + weight + + Initial + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + + + + Calibration + + + High + precision + weight + + FP8 scaling + factors + + High + precision + input + + High + precision + GEMM + + + + FP8 with calibrated scaling factors + + + High + precision + weight + + Calibrated + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + + diff --git a/docs/examples/te_gemma/media/calibration_1_half.svg b/docs/examples/te_gemma/media/calibration_1_half.svg new file mode 100755 index 000000000..478604d41 --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_1_half.svg @@ -0,0 +1,415 @@ + + + + + + + + + + + + + High + precision + weight + + Initial + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + + + + + High + precision + weight + + FP8 scaling + factors + + High + precision + input + + High + precision + GEMM + + + + + FP8 with initial scaling factors + Calibration + + diff --git a/docs/examples/te_gemma/media/calibration_2_half.svg b/docs/examples/te_gemma/media/calibration_2_half.svg new file mode 100644 index 000000000..439f4c16f --- /dev/null +++ b/docs/examples/te_gemma/media/calibration_2_half.svg @@ -0,0 +1,401 @@ + + + + + + + + + + + + Calibration + + + High + precision + weight + + FP8 scaling + factors + + High + precision + input + + High + precision + GEMM + + + + FP8 with calibrated scaling factors + + + High + precision + weight + + Calibrated + FP8 scaling + factors + + FP8 + Weight + + FP8 + Input + + High + precision + input + + FP8 + GEMM + + + + + + + + + diff --git a/docs/examples/te_gemma/media/fp8_model_init.svg b/docs/examples/te_gemma/media/fp8_model_init.svg new file mode 100644 index 000000000..57af23dc3 --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init.svg @@ -0,0 +1,500 @@ + + + + + + + + + + FP32/BF16 + + FP8 + FP8 with fp8_model_init() + + + FP8 + weight + + FP8 + GEMM + + + + + High + precision + weight + + High 
+ precision + input + + High + precision + GEMM + + + + + High + precision + weight + + FP8 + Weight + + + FP8 + Input + + + FP8 + GEMM + + + + + + High + precision + input + + + FP8 + Input + + + + + High + precision + input + + diff --git a/docs/examples/te_gemma/media/fp8_model_init_1_half.svg b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg new file mode 100644 index 000000000..d86751e07 --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg @@ -0,0 +1,358 @@ + + + + + + + + + + + FP32/BF16 + + + + High + precision + weight + + High + precision + input + + High + precision + GEMM + + + FP8 + + + High + precision + weight + + FP8 + Weight + + + FP8 + Input + + + FP8 + GEMM + + + + + + High + precision + input + + diff --git a/docs/examples/te_gemma/media/fp8_model_init_2_half.svg b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg new file mode 100644 index 000000000..c3e4146ba --- /dev/null +++ b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg @@ -0,0 +1,371 @@ + + + + + + + + + + + FP8 + FP8 with fp8_model_init() + + + FP8 + weight + + FP8 + GEMM + + + + High + precision + weight + + FP8 + Weight + + + FP8 + Input + + + FP8 + GEMM + + + + + + High + precision + input + + + FP8 + Input + + + + + High + precision + input + diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif new file mode 100644 index 000000000..25150cb9b Binary files /dev/null and b/docs/examples/te_gemma/media/generation_animation.gif differ diff --git a/docs/examples/te_gemma/media/graphs.svg b/docs/examples/te_gemma/media/graphs.svg new file mode 100644 index 000000000..fc77387af --- /dev/null +++ b/docs/examples/te_gemma/media/graphs.svg @@ -0,0 +1,232 @@ + + + + + + + + + + + Without CUDA Graphs + With CUDA Graphs + + Launch 1 + + Kernel 1 + + Launch 2 + + Kernel 2 + + Launch 3 + + Kernel 3 + + + Launch Graph 1 + + Kernel 1 + + Kernel 2 + + Kernel 3 + + + diff --git a/docs/examples/te_gemma/media/transformer_cuda_graphed.png b/docs/examples/te_gemma/media/transformer_cuda_graphed.png new file mode 100644 index 000000000..cf22822ba Binary files /dev/null and b/docs/examples/te_gemma/media/transformer_cuda_graphed.png differ diff --git a/docs/examples/te_gemma/requirements.txt b/docs/examples/te_gemma/requirements.txt new file mode 100755 index 000000000..a4eaeea43 --- /dev/null +++ b/docs/examples/te_gemma/requirements.txt @@ -0,0 +1,4 @@ +transformers==4.55.0 +accelerate==1.10.0 +datasets==4.0.0 +sentencepiece==0.2.1 diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py new file mode 100755 index 000000000..6285fea1a --- /dev/null +++ b/docs/examples/te_gemma/te_gemma.py @@ -0,0 +1,703 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +from contextlib import contextmanager + +from typing import Optional +from functools import partial +from collections import OrderedDict + +import torch +from torch.amp import autocast + +import transformer_engine as te +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding +from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.pytorch.fp8 import get_default_fp8_recipe +import transformers +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel + +import torch.nn.functional as F + +""" +Top level description of the classes used in the tutorial from this file. +---------------------------------------------------------------------- + +HuggingFace Gemma Model implementation hierarchy: +---------------------------------- +GemmaDecoderLayer: +├── self_attn: +│ ├── norm: (nn.LayerNorm) +│ ├── qkv_proj: (nn.Linear) +│ ├── attention: (SDPA, FlashAttention, etc.) +│ └── o_proj: (nn.Linear) +├── ffn: +│ ├── norm: (nn.LayerNorm) +│ ├── gate_proj: (nn.Linear) +│ ├── up_proj: (nn.Linear) +│ └── down_proj: (nn.Linear) + +GemmaModel: +├── embed_tokens : Token embedding layer +├── layers : GemmaDecoderLayer × N +├── norm : GemmaRMSNorm +└── rotary_emb : GemmaRotaryEmbedding + +GemmaForCausalLM: +├── model : instance of GemmaModel +├── lm_head : (nn.Linear) hidden states to vocabulary logits for generation +└── generate : generate method (input prompt -> GemmaForCausalLM -> next tokens) + +How `generate()` works in HF's GemmaForCausalLM: + 1. prefill (input prompt -> model -> lm_head -> logits -> next token) + 2. loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: Notice how "prefill" and "loop until next tokens" are just part of the `generate()` method. + This is a common pattern in HF models. + + +TransformerEngine's Gemma Model Hierarchy: +---------------------------------------- +HF's `GemmaDecoderLayer` is monkey-patched with `TEGemmaDecoderLayer` before `GemmaForCausalLM` is initialized. This way, +while the model is downloaded from HuggingFace and most of the code runs from HF's `GemmaForCausalLM`, the underlying +blocks of "transformer layer" are actually from TransformerEngine. + +TEGemmaDecoderLayer (inherits from te.TransformerLayer): +├── te.MultiHeadAttention: +│ ├── linear_qkv: (te.LayerNormLinear) +│ ├── attention: (te.DotProductAttention) +│ └── out_proj: (te.LayerNormLinear) +├── te.LayerNormMLP: +│ ├── fc1: (te.LayerNormLinear) +│ ├── fc2: (te.Linear) +│ └── activation: (te.GeGLU) + +To be able to use `model.generate()`, an entry point is needed. `TEGemmaForCausalLM` is the entry point which +subclasses HF's `GemmaForCausalLM` and adds a few attributes and methods. 
+ +TEGemmaForCausalLM (inherits from HF's GemmaForCausalLM) +├─ model : inherited from HF's GemmaForCausalLM but with monkey-patched TEGemmaDecoderLayer × N +├─ lm_head : directly inherited from HF's GemmaForCausalLM +├─ te_rope_emb : RotaryPositionEmbedding (reusing the same for all layers for CUDA graphs compatibility) +├─ hidden_states_buffer : shape [b, max_ctx, h] (static) +├─ generation_buffer : shape [b, 1, h] (view of `hidden_states_buffer`) (static) +├─ inference_params : TransformerEngine KV cache +├─ model_context_phase : GemmaModelWrapper → uses (model, lm_head, inference_params) for full-sequence prefill +├─ model_generation_phase : GemmaGenerationWrapper → uses (model, lm_head, inference_params) for single-token decode +└─ generate : generate method (input prompt -> TEGemmaForCausalLM -> next tokens) + +Notice how "prefill" and "loop until next tokens" are specialized to wrapper subroutines - "model_context_phase" and +"model_generation_phase" respectively which makes it easier to use CUDA Graphs. Just one more abstraction is needed: + +TEGemmaForCausalLMCudaGraphs (inherits from TEGemmaForCausalLM) +├─ model : unchanged (HF's GemmaModel with monkey-patched TEGemmaDecoderLayer × N) +├─ lm_head : unchanged +├─ hidden_states_buffer : unchanged +├─ generation_buffer : unchanged +├─ inference_params : unchanged +├─ record : utility function to record the graphed callable +├─ model_context_phase : GraphedCallable(for Context/prefill) replaced by `record` +├─ model_generation_phase : GraphedCallable(for Generation) replaced by `record` +└─ generate : unchanged + +How `generate()` works in TEGemmaForCausalLM/TEGemmaForCausalLMCudaGraphs: + 1. model_context_phase (input prompt -> model -> lm_head -> logits -> next token) + 2. model_generation_phase: + - loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: In the tutorial, `record` is called when initializing the model. + +Additional notes and clarifications +----------------------------------- +- Wrappers, not submodules: + `model_context_phase` and `model_generation_phase` are convenience wrappers over the same + `model` (GemmaModel) and `lm_head`. They own no parameters; they standardize buffer usage, + masks (context uses "padding_causal", generation uses "padding"), rotary embeddings, and + KV-cache (`InferenceParams`) flow for TE-optimized inference. + +- Buffer relationship: + `hidden_states_buffer` has shape [b, max_ctx, h]. `generation_buffer` is a contiguous view + of size [b, 1, h] carved from its start to avoid non-contiguous indexing. Generation updates + `generation_buffer` in-place with next-token embeddings. + +- Padding policy: + Inputs may arrive left-padded (HF-style). Before TE execution, padding is shifted to the end + to match TE attention mask expectations and to keep shapes contiguous for capture/replay. + +- CUDA Graphs specifics: + `record()` captures two separate callables (context/prefill and generation) with fixed shapes and + stable pointers, then replaces the wrappers with these GraphedCallables. Under graphs, the + functional behavior is identical; only allocation/pointer churn and CPU overhead are removed. +""" + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. 
+ + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + + def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs): + + self.gemma_config = config + + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="geglu", + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + kv_channels=self.gemma_config.head_dim, + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. + zero_centered_gamma=True, + ) + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. + + # filter out HF specific args + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in keys_to_remove: + kwargs.pop(key, None) + + rope_emb = kwargs.pop("rope_emb", None) + + # Return tuple to be compatible with HF. + return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),) + + +class GemmaModelWrapper(torch.nn.Module): + """ + Encapsulates the HuggingFace GemmaModel class as a wrapper whose + forward pass is compatible with CUDA Graphs. + """ + + def __init__( + self, + model: GemmaModel, + dtype: torch.dtype, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + + for i, decoder_layer in enumerate(self.model.layers): + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type, + inference_params=self.inference_params, + rope_emb=rope_emb, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + + # This is not needed for generation but is needed for training + # or finetuning. + if self.training: + logits = logits.float() + + return logits + + +class GemmaGenerationWrapper(torch.nn.Module): + """ + Gets token embeddings for a batch of single tokens, runs forward pass, and + returns the batch ofnext tokens. Also compatible with CUDA graphs. Not a + subclass of `GemmaModel` since the model layers are simply reused here. 
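+ The next-token embedding is copied back in place into the shared generation buffer, so the same static tensor address can be reused across CUDA graph replays.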
+ """ + + def __init__( + self, + model: GemmaModel, + lm_head: torch.nn.Module, + dtype: torch.dtype, + ): + super().__init__() + self.model = model + self.gemma_layers = GemmaModelWrapper(model, dtype, lm_head) + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + logits = self.gemma_layers( + hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb + ) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + + # Fetch the logits for the last token + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + return next_tokens + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Monkey-patches `GemmaDecoderLayer` with the custom `TEGemmaDecoderLayer` + class. + """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. + + Args: + config: Gemma model config that HF uses to initialize the model. + """ + + def __init__(self, config: GemmaConfig): + + dtype = torch.bfloat16 + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + super().__init__(config) + + self.config = config + self.to(dtype).cuda() + self.hidden_size = config.hidden_size + + self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head) + + self._model_generation_phase = GemmaGenerationWrapper( + lm_head=self.lm_head, + model=self.model, + dtype=dtype, + ) + + if self.config.fp8: + self.fp8_recipe = get_default_fp8_recipe() + + # Rotary position embedding remains the same for all the layers and so + # created here. This makes it compatible with CUDA Graphs too. + self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)( + max_seq_len=self.config.max_position_embeddings + ).cuda() + + @staticmethod + def _padding_to_end(inputs, lengths, max_seq_len=None): + """ + Gets the tensor with sequence padded from the beginning and + updates it inplace to be padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + + # Trim the inputs to no extra padding i.e. 
fix the max seq len to + # the longest sequence in the batch + actual_max_seq_len = max_seq_len + inputs.data = new_input_ids[:, :actual_max_seq_len] + + def _create_or_fetch_hidden_states_buffer(self, input_ids: torch.Tensor): + """ + Returns a tensor of shape [b, s, hd] where `b` is the batch size, + `s` is the sequence length, and `hd` is the hidden size. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. + """ + + tensor = torch.empty( + (input_ids.shape[0], input_ids.shape[1], self.hidden_size), + device="cuda", + dtype=torch.float32, + ) + return tensor + + def _create_or_fetch_inference_params(self, *args, **kwargs): + """ + Creates an InferenceParams object. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. + """ + + infer_params = InferenceParams(*args, **kwargs) + return infer_params + + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + """ + Returns a tensor of shape [b, 1, hd] where `b` is the batch size, + `hd` is the hidden size. + + The buffer for generation is some part (beginning) of hidden states buffer. + This function returns pointer to it and also copies there data if provided. + """ + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that `hidden_states_buffer[:, 0, :].unsqueeze(1)` will return + # uncontiguous buffer, which we want to avoid. + output = hidden_states_buffer.view(-1)[ + : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2] + ] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view( + (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2]) + ) + return generation_buffer + + def setup_and_run_context_phase( + self, input_ids: torch.Tensor, inference_params: InferenceParams + ): + """ + Runs the context or prefill phase of the model. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. + """ + + hidden_states = self._create_or_fetch_hidden_states_buffer(input_ids) + hidden_states.copy_(self.model.embed_tokens(input_ids)) + + # Update offsets before every forward pass (including context/prefill + # phase) to make cache work properly. + lengths = input_ids.ne(0).sum(dim=1) + inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))) + + logits = self._model_context_phase( + hidden_states, + attention_mask=None, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + logits = logits[torch.arange(logits.size(0)), lengths - 1, :] + next_tokens = torch.argmax(logits, dim=1) + + # `self.hidden_states` has shape [b, s, hd]. + # Return hidden state for the last token - output has shape [b, 1, hd]. + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs, + ): + """ + Generates next tokens auto-regressively for a batch of input tokens. + """ + self.eval() + + # Both autocasts are needed: FP8 for operations that can run in lower + # precision and BF16 for those that cannot. 
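+        # If `config.fp8` is False, `fp8_autocast` below is entered with
+        # enabled=False, so generation effectively runs in BF16 only.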
+ with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() + # If padding is at the beginning, then shift it to the end + TEGemmaForCausalLM._padding_to_end( + input_ids, + lengths, + max_seq_len=( + self.config.cuda_graphs_static_max_context_len + if self.config.generation_cuda_graphs + else None + ), + ) + + batch_size = input_ids.shape[0] + # For benchmark generation run, this is being set explicitly. + max_input_sequence_len = self.config.max_seq_length + + # InferenceParams is a cache, where keys and values of previous + # tokens are stored. Moreover it stores the current running lengths + # of the sequences in the current batch. + # A helper function is used to create the inference params object + # because this `generate` method is common for TEGemmaForCausalLM + # and TEGemmaForCausalLMCudaGraphs. In case of CudaGraphs, this + # function is overriden to simply return the inference params object + # that is already created in TEGemmaForCausalLMCudaGraphs' + # constructor. + inference_params = self._create_or_fetch_inference_params( + max_batch_size=batch_size, + max_sequence_length=max_input_sequence_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=batch_size * max_input_sequence_len // 16, + ) + + # Set the inference params for both the context/prefill phase and + # generation phase objects. + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + # Context/prefill phase. + hidden_states, next_tokens = self.setup_and_run_context_phase( + input_ids, inference_params + ) + + # Generation phase. + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + output_tokens = [next_tokens] + + for _ in range(max_new_tokens): + next_tokens = self._model_generation_phase( + hidden_states, + mask=None, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) + + # Increase sequence offsets by one because we generated one token + # for every sequence. + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + + # `next_tokens` is a static output tensor, so we need to clone + # it because it gets changed every iteration. + output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + def forward(self, *args, **kwargs): + """ + Forward pass for the model. This is used in calibration step when + forward pass is needed to generate FP8 calibration data. + """ + + self._model_context_phase.set_inference_params(None) + hidden_states = self.model.embed_tokens(kwargs["input_ids"]) + logits = self._model_context_phase( + hidden_states, + attention_mask=( + kwargs["input_ids"] == 0 + ), # Hardcoded, this only applies to bshd/sbhd layouts. 
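+            # Note: token id 0 is assumed to be the padding token when
+            # building this mask.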
+ attn_mask_type="padding_causal", + ) + return logits + + +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): + """ + TEGemmaForCausalLMCudaGraphs is a wrapper over the class TEGemmaForCausalLM + and uses CUDA Graphs to speed up the generation process. We need to make one + trade-off - batch_size, max_seq_len and max_context_seq_len need to + be static. It is necessary to run generation without changing the pointer + to the variables that are recorded in the graph. + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + + self.config = config + + # Preparation of the static buffer to hold the hidden states that are + # passed from one layer to the next. + self.hidden_states_buffer = torch.empty( + ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + self.config.hidden_size, + ) + ).cuda() + + # This is in fact part of the buffer for hidden_states. Refer to the + # `_get_generation_buffer` function for more details. + self.generation_buffer = self._get_generation_buffer( + self.hidden_states_buffer, + ) + + # InferenceParams contains the keys and values cache. Refer to the + # original call in TEGemmaForCausalLM's `generate` method for more + # details. + self.inference_params = InferenceParams( + max_batch_size=self.config.cuda_graphs_static_batch_size, + max_sequence_length=self.config.cuda_graphs_static_max_context_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=self.config.cuda_graphs_static_batch_size + * self.config.cuda_graphs_static_max_context_len + // 16, + ) + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) + + def record(self): + """ + Here "the trick" happens. `_model_context_phase` and + `_model_generation_phase` from TEGemmaForCausalLM are replaced with + their recorded version. Once the graphs are recorded, they can be + replayed with minimal usage of CPU and that leads to speedup. + """ + # Record the model with training=False, because it will be used in + # generation. + self.eval() + + # Setup the recording for context/prefill phase. + input_shape = ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + ) + + # Hardcoded value for the context length. + lengths = torch.tensor([9] * self.config.cuda_graphs_static_batch_size).to( + device="cuda", dtype=torch.int32 + ) + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # Record the graph for context/prefill phase. + self._model_context_phase = self.record_graph( + self._model_context_phase, + self.hidden_states_buffer, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + # Setup the recording for generation phase. + input_shape = (self.config.cuda_graphs_static_batch_size, 1) + lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32) + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # Record the graph for generation phase. 
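+        # The generation-phase graph is captured over the [batch, 1, hidden]
+        # generation buffer, so every decode step replays one fixed-shape graph.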
+ self._model_generation_phase = self.record_graph( + self._model_generation_phase, + self.generation_buffer, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) + + def _create_or_fetch_hidden_states_buffer(self, *args, **kwargs): + """ + Overriden to make `hidden_states` static i.e. not change its pointer + in memory between every invocation. + + Returns the static buffer for `hidden states` which is already created + in the constructor. This is the same buffer as used in the + context/prefill phase. + """ + return self.hidden_states_buffer + + def _create_or_fetch_inference_params(self, *args, **kwargs): + """ + Overriden to make `inference_params` static i.e. not change its pointer + in memory between every invocation. + + Returns the static buffer for `inference_params` which is already created + in the constructor. + """ + self.inference_params.reset() + return self.inference_params + + @torch.no_grad() + def record_graph(self, function, input_tensor, **sample_kwargs): + """ + Records the graph for the given function. The function is invoked on + argument (self.hidden_states,) and all kernels are recorded. + It then returns the captured callable, which can be run later while + minimizing CPU usage. + """ + fp8_recipe = get_default_fp8_recipe() + + # We need both autocasts: FP8 for operations that can run in lower + # precision and BF16 for those that cannot. + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=self.config.fp8, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=5, + sample_kwargs=sample_kwargs, + ) + return graphed_function diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py new file mode 100755 index 000000000..d0df9edc5 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import re +import gc +import torch + +from typing import List + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import load_state_dict +from transformers.utils.hub import get_checkpoint_shard_files + +""" + This file contains logic of mapping the HuggingFace GemmaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" + + +def _load_weights_for_fp8_model(vanilla_model, hyperparams): + """ + Loads weights and FP8 metadata from a calibrated weights file. + + The weights are in BF16 precision, but the state dict also contains + fp8 metadata computed by the calibration procedure. + """ + + fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename) + + # A hack to remove the extra state from the fp8_metadata_sd + # that contains the extra state from the core_attention module. + fp8_metadata_sd = { + k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k + } + vanilla_model.load_state_dict( + fp8_metadata_sd, + strict=False, + # Because some parameters have multiple pointers to the same weight + # vanilla_model._model_context_phase.model and + # vanilla_model._model_generation_phase.model we need to load the + # weights in a non-strict manner. 
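+        # strict=False additionally tolerates the `core_attention._extra_state`
+        # entries that were filtered out of the state dict above.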
+ ) + + +def _load_weights_for_standard_model(vanilla_model, config): + """ + Loads weights from the HuggingFace checkpoint. + """ + + archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict.update(state_dict) + + replace_params( + total_dict, + vanilla_model.state_dict(), + config, + qkv_fused_and_interleaved=config.fuse_qkv_params, + ) + # Copy remaining parameters like embedding. + vanilla_model.load_state_dict(total_dict, strict=False) + + # Force mem release. Taken from huggingface code. + del total_dict + gc.collect() + + +def load_te_model(cls, config): + """ + Loads the TE model with proper weights. + """ + + # Force the dtype to bfloat16 while loading the model. + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + config.use_cache = False # To make TransformerLayer compatible with GemmaModel + + # Loading model with FP8 only weights needs both the following context managers. + # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights. + # 2. torch.no_grad() during TE modules' initilization so that they respect + # the `fp8_model_init` context manager. + with torch.no_grad(), fp8_model_init(config.fp8_model_init): + # Just create a model with random weights. + vanilla_model = cls(config).cuda() + + # Copy proper weights into the model. If loading weights with FP8 metadata, + # then the source weights are basically the same as the weights in the model. + # If not, then we need to load the weights from the HuggingFace checkpoint + # and do mapping of the weight names from HF to the TE model. + if config.fp8_model_weights_filename is not None: + _load_weights_for_fp8_model(vanilla_model, config) + else: + _load_weights_for_standard_model(vanilla_model, config) + + # Restore the original dtype. + torch.set_default_dtype(old_dtype) + return vanilla_model + + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." + that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. 
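+
+    Assumes `config` provides `intermediate_size`, `num_attention_heads` and
+    `head_dim`, which are used below to slice the fused MLP and QKV weights.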
+ """ + all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + + for layer_prefix in all_layer_prefixes: + + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + te_state_dict[layer_prefix + te_name].data[start:end].copy_( + hf_state_dict[layer_prefix + hf_name] + ) + + copy_from_ht_to_te( + "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight" + ) + copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight") + copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight") + copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight") + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size + ) + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size + ) + + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. + Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. + """ + te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"] + + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + dst_slice = slice( + dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim + ) + src_slice = slice( + head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim + ) + te_qkv_layer[dst_slice, :] = src[src_slice, :] + + copy_interleave("self_attn.q_proj.weight", 0) + copy_interleave("self_attn.k_proj.weight", 1) + copy_interleave("self_attn.v_proj.weight", 2) + else: + copy_from_ht_to_te( + "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight" + ) + copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight") + copy_from_ht_to_te( + "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight" + ) + + return all_layer_prefixes diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb new file mode 100755 index 000000000..cc8675cfd --- /dev/null +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -0,0 +1,941 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "87e8360b-8d08-44bc-9333-79ba949afe8c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Accelerating Hugging Face Gemma Inference with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "id": "2da33092-eef5-46a4-b222-0188cc6e5079", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Introduction\n", + "\n", + "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. 
This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Animation 1: Hugging Face Gemma model token generation.\n", + "
\n", + "
\n", + "\n", + "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", + "\n", + "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n", + "\n", + "This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n", + "\n", + "### 1. From vanilla KV-caching to Paged Attention for inference in Transformer Engine\n", + "\n", + "The original [Attention mechanism](https://arxiv.org/pdf/1706.03762) ushered in an era of Large Language Models, but the same attention mechanism, if used for deployment in inference scenarios, can be computationally wasteful. It is primarily due to a lot of redundant computation that happens in attention when the Transformer models are used autoregressively to compute the next token. Several tutorials on the internet explain in detail how KV Caching helps to reduce that redundant computation, e.g., [tutorial 1](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms), [tutorial 2](https://medium.com/@joaolages/kv-caching-explained-276520203249), etc.\n", + "\n", + "\n", + "Further, even though the performance benefit of KV Cache is immense, it comes at the cost of increased memory usage, which becomes a problem especially for longer context lengths. The major problems are: \n", + "\n", + "1. Internal fragmentation\n", + "2. External Fragmentation\n", + "\n", + "More information can be found in the [Paged Attention](https://arxiv.org/pdf/2309.06180) paper. The authors solve the above problems by treating the KV cache as a virtual memory with the actual physical blocks being much smaller than the overall cache size. This makes it easier to swap them in and out of GPU HBM as needed - very similar to how Operating Systems implement virtual memory to swap the individual pages in and out of the CPU RAM.\n", + "\n", + "\n", + "Transformer Engine allows users to use both \"Non-paged\" and \"Paged\" forms of KV Caching, and the results in this tutorial are posted for both use cases.\n", + "\n", + "\n", + "### 2. CUDA Graphs API\n", + "\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to finish processing and then launch the kernels, which can lead to significant overhead. CUDA Graphs can address this issue. When such blocks of computation are executed repeatedly, CUDA Graphs allow us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where multiple \"Transformer/Decoder Layers\" are run for every token that needs to be generated.\n", + "\n", + "One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. 
More information about the CUDA graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 1: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n", + "
\n", + "
\n", + "\n", + "### 3. FP8 Scaling Factors Calibration\n", + "\n", + "This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n", + "\n", + "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 2:\n", + "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n", + "
\n", + "
\n", + "\n", + "### 4. FP8 Model Weights\n", + "\n", + "The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n", + "\n", + "The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 3: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using fp8_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", + "
\n", + "
\n", + "\n", + "### Benchmarking\n", + "\n", + "We'll evaluate the generation time across one benchmark: token generation with context/prefill phase max sequence length = 20, batch size = 64, and number of generated tokens = 492 on random texts with random lengths. This is a purely synthetic benchmark.\n", + "\n", + "
\n", + "Note\n", + " \n", + "This tutorial focuses on showcasing the mentioned features of the Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT-LLM](https://docs.nvidia.com/tensorrt-llm/index.html), which is optimized for inference tasks and should be considered for such use cases.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b18f91a9", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "e5201d77", + "metadata": {}, + "source": [ + "The following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint weights in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. Further, it contains necessary abstractions like a subclass of `GemmaForCausalLM` - `TEGemmaForCausalLM` that is used for generation with Transformer Engine's `TransformerLayer`, CUDA Graphs, and FP8 calibration for generation in FP8 precision.\n", + "2. `te_gemma_loading_weights.py`\n", + " - This file contains the logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", + "3. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training, and other miscellaneous tasks like restarting the Jupyter notebook from within the cell. \n", + "4. `requirements.txt`\n", + " - This file contains the necessary Python packages for this tutorial.\n", + "5. `media/`\n", + " - This directory contains the images and other artefacts used in this tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "36767694-a1c5-4a00-a075-7addc55d8307", + "metadata": {}, + "source": [ + "### Setup and checks" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1de3351b-fa21-4b95-bb9e-d01ac8bb7edf", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and run this cell when running the tutorial for the first time\n", + "# %pip install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c756ebbd-24c9-4a54-a381-e7c02c555206", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "id": "e8dfabbf", + "metadata": {}, + "source": [ + "## [Baseline] Running Hugging Face generation with Gemma model" + ] + }, + { + "cell_type": "markdown", + "id": "59560bff", + "metadata": {}, + "source": [ + "HuggingFace Transformers library offers generation API. \n", + "HuggingFace generation for the Gemma model will be used as a baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2803e0ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. 
The\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds advanced computer graphics and video processing chips for the PC and video game console markets.\n", + "* The company is a leading provider of graphics processing units (GPUs) for the PC and video game\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 46.60 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "\n", + "model = init_baseline_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b3698dc6", + "metadata": {}, + "source": [ + "Let's put this time into the table for later comparison.\n", + "\n", + "| Models | Time | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 46.6 s | - |" + ] + }, + { + "cell_type": "markdown", + "id": "8bb40f45", + "metadata": {}, + "source": [ + "## [Optimization 1] Accelerating generation with Transformer Engine " + ] + }, + { + "cell_type": "markdown", + "id": "263b40f2", + "metadata": {}, + "source": [ + "Similar to the [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) finetuning tutorial, a `GemmaDecoderLayer` is substituted by a tuned `TransformerLayer` from the Transformer Engine library. Let's run it and compare the time with the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9dceef93", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 12.25 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d40836", + "metadata": {}, + "source": [ + "With just using Transformer Engine with default (non-paged) KV cache, a speedup of **3.8x** was obtained. Neat!" + ] + }, + { + "cell_type": "markdown", + "id": "006d18e8", + "metadata": {}, + "source": [ + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of `GemmaDecoderLayer` with `te.TransformerLayer`) | 12.25 s | 3.8x | 12.24 s | 3.8x |" + ] + }, + { + "cell_type": "markdown", + "id": "21a89d9c", + "metadata": {}, + "source": [ + "## [Optimization 2] More acceleration with CUDA Graphs" + ] + }, + { + "cell_type": "markdown", + "id": "e2d53e7b", + "metadata": {}, + "source": [ + "Transformer Engine includes a function `transformer_engine.pytorch.make_graphed_callables`, which behaves similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from [te_gemma.py](./te_gemma.py) from class `TEGemmaForCausalLMCudaGraphs`:\n", + "```python\n", + " def __init__(self, config : GemmaConfig):\n", + " \"\"\"\n", + " Here \"the trick\" happens. `_model_context_phase` and\n", + " `_model_generation_phase` from TEGemmaForCausalLM are replaced with\n", + " their recorded version. Once the graphs are recorded, they can be\n", + " replayed with minimal usage of CPU and that leads to speedup.\n", + " \"\"\"\n", + " (...)\n", + " # Record the graph for context/prefill phase.\n", + " self._model_context_phase = \n", + " self.record_graph(self._model_context_phase, self.hidden_states_buffer)\n", + "\n", + " (...) 
\n", + " # Record the graph for generation phase.\n", + " self._model_generation_phase = \n", + " self.record_graph(self._model_generation_phase, self.generation_buffer)\n", + "\n", + " @torch.no_grad()\n", + " def record_graph(self, function, input_tensor):\n", + " \"\"\"\n", + " Records the graph for the given function. The function is invoked on\n", + " argument (self.hidden_states,) and all kernels are recorded.\n", + " It then returns the captured callable, which can be run later while\n", + " minimizing CPU usage.\n", + " \"\"\"\n", + " fp8_recipe = get_default_fp8_recipe()\n", + "\n", + " # We need both autocasts: FP8 for operations that can run in lower\n", + " # precision and BF16 for those that cannot.\n", + " with autocast(\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n", + " graphed_function = te.pytorch.make_graphed_callables(\n", + " function,\n", + " (input_tensor,),\n", + " fp8_enabled=self.config.fp8,\n", + " fp8_recipe=fp8_recipe,\n", + " allow_unused_input=True,\n", + " num_warmup_iters=5,\n", + " sample_kwargs=sample_kwargs,\n", + " )\n", + " return graphed_function\n", + "```\n", + "\n", + "It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.\n", + "\n", + "*Note the usage of static buffers and corresponding configuration in the following cell, which is necessary for CUDA Graphs to function.*" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31a3a8a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 6.39 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.max_seq_length = 512\n", + "run_config.batch_size = 64\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# It is necessary to preallocate a static buffer.\n", + "# CUDA graphs require static input tensors for every kernel.\n", + "# This approach may result in a slight increase in memory consumption;\n", + "# however, the substantial speedup achieved makes it worthwhile.\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "A speed up of **7.2x** was obtained by using CUDA Graphs with TE's `TransformerLayer`.\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "0a11b75c", + "metadata": {}, + "source": [ + "Let's profile the code from one of the cells above, which runs generation with the Gemma model, and examine the resulting traces in [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to understand the performance characteristics and sources of speedup. A few things to recap:\n", + "\n", + "1. For the TE Gemma model implementation, `model.generate()` internally calls `model_context_phase` and `model_generation_phase`.\n", + "2. They are just wrappers around the Gemma model's layers, and they are graphed separately when CUDA graphs are enabled.\n", + "3. 
So, for each token generated (after the first token), a single invocation of `model_generation_phase` happens as a complete CUDA graph. \n", + "4. The following illustration zooms in on a single `TransformerLayer` layer forward pass (within the larger `model_generation_phase` graphed callable) for clarity.\n", + "\n", + "(For details, refer to the implementation in [te_gemma.py](./te_gemma.py))\n", + "\n", + "
\n", + "\n", + "
\n", + " \n", + "Figure 4: (Without CUDA graphs) Blue blobs in the top figure are GPU kernels, and whitespace b/w those indicates that GPUs are idle waiting for the CPU to finish processing and then launch kernels. (With CUDA graphs) The whitespace gets virtually eliminated because all the GPU kernels are bundled into a single highly optimized unit of work with no CPU time in between. (Note that for reference, the kernels are mapped across both cases, and the sizes of those kernels only seem different because of the presence of large voids in the former case, but the sizes are actually the same.)\n", + "
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "e6b171a0", + "metadata": {}, + "source": [ + "## [Optimization 3] Even more acceleration with FP8 precision " + ] + }, + { + "cell_type": "markdown", + "id": "1a80288b", + "metadata": {}, + "source": [ + "### Calibrating FP8 scaling factors for correctness\n", + "\n", + "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", + "\n", + "1. Model weight tensors\n", + "2. Input tensors\n", + "\n", + "If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n", + "\n", + "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", + "\n", + "*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n", + " \n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 5: The default FP8 scaling factors are incorrect, and so the BF16 to FP8 conversion, as is, can lead to numerical errors. Calibration allows for collecting statistics/metadata about the input and weight tensors in higher precision during the forward pass.\n", + "
\n", + "
\n", + "\n", + "\n", + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent steps." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aecee0e1", + "metadata": {}, + "outputs": [], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "import transformer_engine.pytorch as te\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "run_config.fuse_qkv_params = True\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "# Calibration\n", + "with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " model.train()\n", + " run_forward_pass(model, run_config, num_iters=64)\n", + "\n", + "# Compute scale_fwd with enabled fp8 autocast\n", + "with te.fp8_autocast(enabled=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " run_forward_pass(model, run_config, 1)\n", + "\n", + "# Some parameters are in pointing to the same tensors, double save is avoided here.\n", + "dict_to_save = {\n", + " k: v\n", + " for k, v in model.state_dict().items()\n", + " if (\"_context_phase\" not in k and \"_generation_phase\" not in k)\n", + "}\n", + "torch.save(\n", + " dict_to_save, \"calibrated_weights.pth\"\n", + ") # <-- Add path to save calibrated weights." + ] + }, + { + "cell_type": "markdown", + "id": "b6dcd135", + "metadata": {}, + "source": [ + "### Generation with better FP8 scaling factors\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 6: After the calibration process, FP8 scaling factors are correct and prevent numerical errors.\n", + "
\n", + "
\n", + "\n", + "Now that the calibration has produced correct scaling factors, FP8 inference is ready to be run." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a913f54d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 8.73 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8\n", + "run_config.fp8 = True\n", + "# Calibrated fp8 weights are loaded directly from the file.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "8cdbb56c", + "metadata": {}, + "source": [ + "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n", + "\n", + "### Use of FP8-only model weights\n", + "\n", + "Running the model in FP8 precision does not imply that the weights are stored in FP8. 
By default, they are stored in higher precision and are cast to FP8, using saved scaling factors before GEMM operations (matrix multiplications).\n", + "\n", + "This approach is appropriate during training since gradients during the backward pass are produced in higher precision, and therefore, having higher precision copies of model weights helps, as they have enough dynamic range to encompass incoming information from the gradients. During the forward pass, the higher precision model weights and the batch inputs are cast to FP8, and the GEMMs occur in FP8 precision, which helps save training time overall if the time saved from running GEMM in FP8 precision (than in higher precision) is more than the extra time spent during the cast operation.\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 7: Running the model at higher precision involves only one operation - GEMM. However, when the model operates in FP8, it requires casting inputs to the GEMM - namely, model weights and batch inputs from higher precision to FP8, which involves extra kernels in addition to the low-precision GEMM kernel.\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "626aefa1-d5c4-4d8f-88d9-7d7943afde0d", + "metadata": {}, + "source": [ + "However, things change during inference. Since the weights need no update and remain frozen, higher precision copies of weights could be avoided completely. It is possible to cast the higher precision weights only once to FP8 precision while initializing the model with appropriate scaling factors and then use those FP8-only copies of weights during the entirety of token generation. This provides two-fold benefits:\n", + "\n", + "1. Lower memory usage - since the model weights are stored in FP8 precision only (compared to training, where both BF16 and FP8 copies end up being present in the memory during peak usage).\n", + "2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n", + "\n", + "\n", + "Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4562ee82-8c95-4736-8815-cd386078a485", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory required for 16384x16384 linear layer: \n", + "FP32 - 1024.0 MB, \n", + "BF16 - 512.0 MB, \n", + "FP8 - 256.0 MB, \n", + "\n", + "Actual GPU memory usage with a TE FP32 linear layer: 1024.06 MB\n", + "Actual GPU memory usage with a TE BF16 linear layer: 512.03 MB\n", + "Actual GPU memory usage with a TE FP8 linear layer: 256.08 MB\n" + ] + } + ], + "source": [ + "import torch\n", + "import transformer_engine.pytorch as te\n", + "\n", + "H = 2**14\n", + "D = 2**14\n", + "print(f\"Memory required for {H}x{D} linear layer: \\n\"\n", + " f\"FP32 - {H*D*4/1024**2} MB, \\n\"\n", + " f\"BF16 - {H*D*2/1024**2} MB, \\n\"\n", + " f\"FP8 - {H*D*1/1024**2} MB, \\n\")\n", + "\n", + "linear_fp32 = te.Linear(H, D, params_dtype=torch.float32) \n", + "print(f\"Actual GPU memory usage with a TE FP32 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp32\n", + "\n", + "linear_bf16 = te.Linear(H, D, params_dtype=torch.bfloat16)\n", + "print(f\"Actual GPU memory usage with a TE BF16 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_bf16\n", + "\n", + "# Initialize model weights in FP8 precision\n", + "with torch.no_grad(), te.fp8_model_init(enabled=True):\n", + " linear_fp8 = te.Linear(H, D)\n", + "print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp8" + ] + }, + { + "cell_type": "markdown", + "id": "2a26aba9-f3ba-42c4-b4c3-9e845502ae1b", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "
\n", + " Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", + "
\n", + "
\n", + "\n", + "Let's run the code with `fp8_model_init`:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "96264b9c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 4.99 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8 math and FP8 model weights\n", + "run_config.fp8 = True\n", + "run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "3e30ca5a", + "metadata": {}, + "source": [ + "The final speedup is **9.3x**. 
\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "c6e87275", + "metadata": {}, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "id": "7bb2452d", + "metadata": {}, + "source": [ + "This tutorial focuses primarily on making the token generation faster with an off-the-shelf model downloaded from Hugging Face using the following features of the Transformer Engine:\n", + "\n", + "1. Support for KV Caching (both non-paged and paged),\n", + "2. Integration with CUDA Graphs,\n", + "3. FP8 scaling factors calibration,\n", + "4. Keeping model parameters in FP8 precision.\n", + "\n", + "It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n", + "\n", + "1. Longer context lengths (with paged KV cache) \n", + "2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n", + "\n", + "Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py new file mode 100755 index 000000000..cc31afc65 --- /dev/null +++ b/docs/examples/te_gemma/utils.py @@ -0,0 +1,370 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import sys +import IPython +import random +import string + +from te_gemma_loading_weights import load_te_model +import torch +from torch.utils.data import DataLoader + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + AutoConfig, +) +from transformers import DataCollatorForLanguageModeling +from datasets import load_dataset + + +from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs + +random.seed(42) +torch.manual_seed(42) + + +class RunConfiguration: + def __init__(self): + self.mixed_precision = "bf16" + self.model_name = None + + # FP8 precision settings + self.fp8 = False + self.fp8_model_weights_filename = None + self.fp8_model_init = False + + # Cuda graphs + self.generation_cuda_graphs = False + self.cuda_graphs_static_batch_size = 64 + self.cuda_graphs_static_max_seq_len = 512 + self.cuda_graphs_static_max_context_len = 512 + + # Finetuning/calibration/generation settings + self.dataset_name = "timdettmers/openassistant-guanaco" + self.dataset_text_field = "text" + self.learning_rate = 1.41e-5 + self.batch_size = 64 + self.max_seq_length = 512 + self.gradient_accumulation_steps = 1 + self.num_warmup_steps = 5 + self.num_training_steps = 10 + + # Coalesced QKV params or not + self.fuse_qkv_params = False + + # Attention + self.is_paged = False + + # This is either provided by the user or it will be set when the + # model weights are downloaded. + self.weights_cache_dir = "" + + +# Global variable for the run configuration so that it can be easily accessed +# throughout the jupyter notebook with an `import * from utils` statement +run_config = RunConfiguration() + + +def get_dataloaders(run_config): + """ + Returns a basic dataloader for the dataset which contains tokenized batches + of text. + """ + dataset = load_dataset(run_config.dataset_name, split="train") + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(element): + outputs = tokenizer( + element["text"], + truncation=True, + padding=False, + max_length=run_config.max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + # Tokenize the dataset + dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) + + # Simply pad to the multiple of 16 for both FP8 and BF16 precision + pad_to_multiple_of = 16 + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of, + ) + + dataloader_params = { + "batch_size": run_config.batch_size, + "collate_fn": data_collator, + "drop_last": True, + } + train_dataloader = DataLoader(dataset, **dataloader_params) + return train_dataloader + + +def ensure_model_is_downloaded(run_config): + """ + Downloads and caches the model weights if not already downloaded. A valid + Huggingface Access Token is required to download the model weights. + """ + assert run_config.model_name in [ + "google/gemma-7b", + ], "Only Gemma 7B model is supported!" + + # Login using Huggingface Hub API + from huggingface_hub import login + + try: + login(run_config.hf_access_token) + except Exception as e: + if "Invalid token passed!" in str(e): + print( + "Please pass a valid HF Access Token! More info at" + " https://huggingface.co/docs/hub/en/security-tokens." 
+ ) + else: + print(f"Exception is {e}") + + # Download the model if it doesn't exist + from huggingface_hub import snapshot_download + + supplied_cache_dir = ( + run_config.weights_cache_dir if run_config.weights_cache_dir != "" else None + ) + run_config.weights_cache_dir = snapshot_download( + repo_id=run_config.model_name, cache_dir=supplied_cache_dir + ) + + +def init_baseline_model(run_config): + """ + Initializes a baseline HF Gemma model with the model name provided in + the run_config. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + # Init the model + config = AutoConfig.from_pretrained(run_config.model_name) + + # Make sure to use flash_attention to do iso comparison with TEGemmaModel + config._attn_implementation = "flash_attention_2" + model = AutoModelForCausalLM.from_pretrained( + run_config.model_name, + config=config, + torch_dtype=torch.bfloat16, + ).cuda() + + return model + + +def init_te_gemma_model(run_config): + """ + Initializes a Gemma model with `GemmaDecoderLayer`s swapped with + `TransformerLayer`s from TransformerEngine. In case CUDA Graphs are enabled, + the model is initialized from `TEGemmaForCausalLMCudaGraphs` class. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + cls = TEGemmaForCausalLMCudaGraphs if run_config.generation_cuda_graphs else TEGemmaForCausalLM + config = AutoConfig.from_pretrained(run_config.model_name) + + # Inject all fields from the `run_config` to the model `config` to make the + # code simpler. + for key, value in run_config.__dict__.items(): + setattr(config, key, value) + + # Initialize the model and move it to the GPU. + model = load_te_model(cls, config).cuda() + + # Record the model if CUDA Graphs are enabled. + if run_config.generation_cuda_graphs: + model.record() + + return model + + +def restart_jupyter_notebook(): + # Try restarting the Jupyter kernel + IPython.Application.instance().kernel.do_shutdown(True) + + # Check whether the device memory has been flushed + if torch.cuda.memory_allocated() != 0: + import warnings + + warnings.warn("The device memory hasn't been flushed, trying with a second method!") + + # Try restarting the Jupyter kernel another way + # Restart the kernel + from IPython.core.display import HTML + + HTML("") + + if torch.cuda.memory_allocated() != 0: + print( + "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!" + ) + + # Suppress the warnings + if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + torch.set_warn_always(False) + + +@torch.no_grad() +def run_forward_pass(model, run_config, num_iters): + """ + Runs the forward pass of the model with sample data. Intended to use for + warmup and/or calibration. + """ + train_dataloader = get_dataloaders(run_config) + + model.train() + train_dataloader = enumerate(train_dataloader) + + for _ in range(num_iters): + _, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + batch["attention_mask"] = batch["attention_mask"].cuda() + model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) + + +############################################################################### +# Benchmarking and example generation functions. +############################################################################### + + +def print_sample_of_generated_texts(model, run_config): + """ + Prints a sample of generated texts from the input model. 
+ """ + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + prompts = [ + "Here are the two facts about GPUs:", + "Some facts about NVIDIA:", + "The fundamental theorem of calculus for the layman:", + "A fact about AI:", + ] + + # Repeat prompts to match batch size + prompts *= run_config.batch_size // len(prompts) + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + + max_length = inputs["input_ids"].size(1) + new_length = ((max_length + 63) // 64) * max_total_tokens + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (new_length - max_length, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + outputs = model.generate(**inputs, max_new_tokens=50) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def print_output(prompts, generated_texts, idx): + print("=" * 30 + f" Generation example {idx+1} " + "=" * 30) + print(f'Prompt: "{generated_texts[idx][: len(prompts[idx])]}"') + print(f'Generated text: "{generated_texts[idx][len(prompts[idx]) :]}"') + + # Print the output from first two prompts + for i in range(2): + print_output(prompts, generated_texts, i) + + +def _generate_random_words(num_words, max_word_length): + """ + Generates random words for the benchmark. + """ + + words = [] + for _ in range(num_words): + word_length = random.randint(1, max_word_length) + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + return words + + +def benchmark_generation(model, run_config, context_length=20): + """ + Benchmarks the generation time for a random input to the model. 
+ """ + + batch_size = run_config.batch_size + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + max_new_tokens = max_total_tokens - context_length + + print("\n" + "=" * 80) + print( + f"Benchmarking for batch_size = {batch_size}, prefill tokens =" + f" {context_length} and max new tokens = {max_new_tokens}" + ) + + input_str = _generate_random_words(batch_size, context_length) + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + inputs = tokenizer(input_str, return_tensors="pt", padding=True) + + max_context_tokens = inputs["input_ids"].size(1) + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (max_total_tokens - max_context_tokens, 0), + value=tokenizer.pad_token_id, + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (max_total_tokens - max_context_tokens, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + model.generate(inputs["input_ids"].cuda(), max_new_tokens=max_new_tokens) + torch.cuda.synchronize() + end.record() + + print(f"Time: {start.elapsed_time(end)/1000:.2f} s.") diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 7013e85ec..00499cff5 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -5,7 +5,7 @@ "id": "6a5b2993", "metadata": {}, "source": [ - "# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n", + "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n", "\n", "
\n", "\n", diff --git a/docs/index.rst b/docs/index.rst index bbdb4fea6..2c04810f4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,8 @@ Transformer Engine documentation examples/fp8_primer.ipynb examples/advanced_optimizations.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb + examples/te_gemma/tutorial_generation_gemma_with_te.ipynb + examples/onnx/onnx_export.ipynb .. toctree:: :hidden: diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index b761b1381..3855db275 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -269,7 +269,10 @@ def train_and_evaluate(args): ) as mesh, te.fp8_autocast( enabled=args.use_fp8, fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + mesh_resource=te.MeshResource( + dp_resource=DEVICE_DP_AXIS, + tpsp_resource=DEVICE_TP_AXIS, + ), ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 26740c025..d6bfddb3e 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -266,7 +266,7 @@ def train_and_evaluate(args): ) as mesh, te.fp8_autocast( enabled=args.use_fp8, fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), + mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS), ): rng = jax.random.PRNGKey(args.seed) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index e8a14a146..420e36ea1 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -384,7 +384,10 @@ def train_and_evaluate(args): ) as mesh, te.fp8_autocast( enabled=args.use_fp8, fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + mesh_resource=te.MeshResource( + dp_resource=DEVICE_DP_AXIS, + tpsp_resource=DEVICE_TP_AXIS, + ), ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index a9ded61b8..2c5bd7025 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -221,7 +221,9 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast( + enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + ): encoder = Net(num_embed) # We use nn.Embed, thus inputs need to be in int inputs = jnp.zeros(input_shape, dtype=jnp.int32) diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 110705d01..92baf4b0c 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -193,7 +193,9 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast( + enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + ): cnn = Net(args.use_te) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) tx = optax.sgd(args.lr, args.momentum) diff --git 
a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index e510df176..d52e97d65 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -263,7 +263,13 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) te.module.base.initialize_ub( [batched_size, hidden_size], tp_size, - use_fp8=opts.fp8, + quantization_modes=[ + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) + ], dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ) diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 8773c233e..97824bbdb 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -5,7 +5,8 @@ "util/cuda_runtime.h" : "util/hip_runtime.h", "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", - "" : "" + "" : "", + "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)" } } diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 3d00e0346..e4a3f4630 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index e1400b10b..bf9e4a461 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -26,23 +26,24 @@ pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine- VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" + # Core wheel. 
-NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" -wheel unpack dist/* || error_exit "Failed to unpack dist/*" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel" +wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" rm dist/*.whl || error_exit "Failed to remove dist/*.whl" mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage" cd transformer_engine/jax -NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup sdist" -pip3 install dist/* || error_exit "Failed to install dist/*" +pip3 install --no-build-isolation --no-deps -vvv dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" +pip3 install --no-build-isolation --no-deps -vvv dist/*.whl || error_exit "Failed to install dist/*.whl --no-deps" python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py" diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index c94edba2b..b4bf0a024 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -14,14 +14,23 @@ FAIL=0 +# It is not installed as a requirement, +# because it is not available on PyPI. 
+pip uninstall -y nvdlfw-inspect +pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git + pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 + # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 exit $FAIL diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 7fe439b37..394273ca4 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -23,33 +23,30 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime" -pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions" - -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || 
test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" + +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml 
$TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" 
+NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index ffd5ca290..3056547ef 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -27,22 +27,22 @@ VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. -NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" -wheel unpack dist/* || error_exit "Failed to unpack dist/*" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel" +wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" rm dist/*.whl || error_exit "Failed to remove dist/*.whl" mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage" cd transformer_engine/pytorch -NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup sdist" -pip3 install dist/* || error_exit "Failed to install dist/*" +pip3 install --no-build-isolation --no-deps -vvv dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" +pip3 install --no-build-isolation --no-deps -vvv dist/*.whl || error_exit "Failed to install dist/*.whl --no-deps" python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py" diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh new file mode 100755 index 000000000..e074b46ae --- /dev/null +++ b/qa/L1_cpp_distributed/test.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Find TE +: ${TE_PATH:=/opt/transformerengine} +TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') +export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH + +if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then + cd $TE_PATH/tests/cpp_distributed + cmake -GNinja -S. 
-Bbuild + cmake --build build + mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm +fi diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 5deb77af9..8ecc5a917 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -8,4 +8,5 @@ set -xe : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 09ef661c4..7f061d222 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -21,14 +21,21 @@ FAILED_CASES="" mkdir -p "$XML_LOG_DIR" +# It is not installed as a requirement, +# because it is not available on PyPI. +pip uninstall -y nvdlfw-inspect +pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git + pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh new file mode 100644 index 000000000..1486d5097 --- /dev/null +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ + +pip3 install onnxruntime==1.20.1 +pip3 install onnxruntime_extensions==0.13.0 + +: ${TE_PATH:=/opt/transformerengine} + +python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L2_jax_distributed_unittest/test.sh b/qa/L2_jax_distributed_unittest/test.sh new file mode 100644 index 000000000..0b7372650 --- /dev/null +++ b/qa/L2_jax_distributed_unittest/test.sh @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -xe + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index c5c193351..f933a0732 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +NVTE_JAX_CUSTOM_CALLS="false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 547849e95..7e9616cd0 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -41,6 +41,6 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py done diff --git a/setup.py b/setup.py index f002e2edf..1ae476311 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ """Installation script.""" +from importlib import metadata import os import time from pathlib import Path @@ -23,6 +24,7 @@ all_files_in_dir, hipify, cuda_archs, + cuda_version, get_frameworks, remove_dups, ) @@ -112,6 +114,18 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))): + cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON") + cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution( + f"nvidia-cublasmp-cu{cuda_version()[0]}" + ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") + cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") + nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( + f"nvidia-nvshmem-cu{cuda_version()[0]}" + ).locate_file("nvidia/nvshmem") + 
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") + print("CMAKE_FLAGS:", cmake_flags[-2:]) + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 75c52fdd7..b71addebf 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -64,7 +64,7 @@ else() project(transformer_engine_tests LANGUAGES HIP CXX) # Ask hcc to generate device code during compilation so we can use # host linker to link. - set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result -ftemplate-backtrace-limit=0") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result -Wno-unused-value -ftemplate-backtrace-limit=0") endif() add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) @@ -85,6 +85,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_ message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) include_directories(${CMAKE_SOURCE_DIR}) if(USE_CUDA) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index e3af4a360..46bcf4242 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -7,6 +7,8 @@ list(APPEND test_cuda_sources test_cast.cu test_cast_current_scaling.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu test_cast_gated_swiglu.cu test_cast_mxfp8_gated_swiglu.cu test_qdq.cu @@ -18,8 +20,6 @@ list(APPEND test_cuda_sources test_cast_transpose_dbias.cu test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu test_act.cu test_normalization.cu test_normalization_mxfp8.cu @@ -28,6 +28,7 @@ list(APPEND test_cuda_sources test_multi_unpadding.cu test_causal_softmax.cu test_swizzle.cu + test_swap_first_dims.cu ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu index 856c24cfc..89ae79f9a 100644 --- a/tests/cpp/operator/test_cast_current_scaling.cu +++ b/tests/cpp/operator/test_cast_current_scaling.cu @@ -219,7 +219,7 @@ TEST(AmaxConsistencyTest, AtomicVsWorkspace) { // Path 2: two-stage amax using workspace std::vector ws_shape{N}; Tensor workspace("workspace", ws_shape, DType::kFloat32); - nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0); + nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), nullptr, 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index b0a847d7d..b635dc00b 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -38,98 +38,37 @@ enum ActivationType { SReLU }; -template -void scale_block(const ProcessingMethod processing_method, +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, const InputType* input, const InputType* grad, - OutputType* output_c, - float* dbias, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ #ifdef __HIP_PLATFORM_AMD__ using std::isnan, std::isinf; #endif - float amax = 0.0f; - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - dbias[j] += elt; - if (isinf(elt) || isnan(elt)) { - continue; - } - amax = std::max(amax, std::abs(elt)); - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal()); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - - // Quantize elements in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - output_c[idx] = static_cast(elt * scale_reciprocal); - } - } -} - -template -void compute_ref_x1(const ProcessingMethod processing_method, - const InputType* input, - const InputType* grad, - OutputType* output_c, - fp8e8m0* output_scales, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) -{ - const size_t tile_size_Y = std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; std::vector output_dbias_fp32(cols, 0); #pragma omp parallel proc_bind(spread) { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + std::vector thread_dbias(cols, 0); #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * 
tiles_num_X; ++t) { @@ -138,24 +77,83 @@ void compute_ref_x1(const ProcessingMethod processing_method, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X; - scale_block( - processing_method, input, grad, output_c, thread_dbias.data(), - output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + thread_dbias[j] += elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const size_t scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; 
++i) { + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } } } } @@ -171,29 +169,6 @@ void compute_ref_x1(const ProcessingMethod processing_method, } } -template -void compute_ref_x2(const ProcessingMethod processing_method, - const InputType* input, - const InputType* grad, - OutputType* output_rowwise, - OutputType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, - rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - processing_method, input, grad, output_colwise, scales_colwise, output_dbias, - rows, cols, block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -202,8 +177,9 @@ void compute_ref_x2(const ProcessingMethod processing_method, * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const bool rowwise, const bool colwise, @@ -266,28 +242,46 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output_c.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output_c.data(), 0); + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output_c.data(), 0); break; } } @@ -296,47 +290,59 @@ void performTest_x1(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, 
cudaSuccess) << cudaGetErrorString(err); - compute_ref_x1(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c.get(), - ref_output_scales.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); - - -#ifdef __HIP_PLATFORM_AMD__ - if (processing_method != ProcessingMethod::CAST_ONLY) { - std::vector> mismatch_idx; - compare_e8m0_scaling_factors("scales", output_c, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, rowwise, mismatch_idx); - - if (mismatch_idx.size()) { - adjust_ref(mismatch_idx, ref_output_c.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype); - } - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); - } - else -#endif // #ifdef __HIP_PLATFORM_AMD__ - { - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); + compute_ref(processing_method, + OP, + rowwise, + colwise, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c.get(), + ref_output_c.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride, + scales_stride); const uint8_t * const gpu_scales_ptr = rowwise ? output_c.rowwise_cpu_scale_inv_ptr() : output_c.columnwise_cpu_scale_inv_ptr(); + const size_t scale_diff_abs_tolerance = 0; +#ifdef __HIP_PLATFORM_AMD__ + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; +#else + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; +#endif + + std::vector mismatches_scales_indices; + size_t mismatches_scales = 0; compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); - } + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales_indices, mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); - if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr, + ref_output_scales.get(), scales_stride, rows, cols, rowwise, + ref_output_c.get(), otype); + mismatches_scales = 0; +#endif + + const size_t mismatches_elts = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts); + + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -355,8 +361,9 @@ void performTest_x1(const ProcessingMethod processing_method, * AND * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x2(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const size_t block_size_rows, const size_t block_size_cols, @@ -424,28 +431,46 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - 
output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output.data(), 0); + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output.data(), 0); break; } } @@ -454,55 +479,75 @@ void performTest_x2(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - compute_ref_x2(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c_rowwise.get(), - ref_output_c_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); + compute_ref(processing_method, + OP, + true, + true, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise); + + const size_t scale_diff_abs_tolerance = 0; #ifdef __HIP_PLATFORM_AMD__ - if (processing_method != ProcessingMethod::CAST_ONLY) { - std::vector> mismatch_idx_r; - compare_e8m0_scaling_factors("scales_rowwise", output, ref_scales_rowwise.get(), - unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r); - - if (mismatch_idx_r.size()) { - adjust_ref(mismatch_idx_r, ref_output_c_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype); - } - std::vector> mismatch_idx_c; - compare_e8m0_scaling_factors("scales_colwise", output, ref_scales_colwise.get(), - unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c); - - if (mismatch_idx_c.size()) { - adjust_ref(mismatch_idx_c, ref_output_c_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype); - } + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; +#else + const double 
abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; +#endif - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); - } else -#endif // #ifdef __HIP_PLATFORM_AMD__ - { - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); + std::vector mismatches_scales_indices_rowwise; + size_t mismatches_scales_rowwise = 0; compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_indices_rowwise, mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + std::vector mismatches_scales_indices_colwise; + size_t mismatches_scales_colwise = 0; compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise); - } + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_indices_colwise, mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, + output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), scales_stride_rowwise, rows, cols, + true, ref_output_c_rowwise.get(), otype); + mismatches_scales_rowwise = 0; + adjust_ref_for_e8m0_scale_error("scales_colwise", mismatches_scales_indices_colwise, + output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), scales_stride_colwise, rows, cols, + false, ref_output_c_colwise.get(), otype); + mismatches_scales_colwise = 0; +#endif + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; - if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise); + + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -521,11 +566,10 @@ std::vector> matrix_sizes = { {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, - {2048, 6144}, - {16384, 128}, - {32768, 160}, - {4096, 1632}, + {511, 6144}, + {8192, 128}, + {2048, 160}, + {577, 1632}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, @@ -574,26 +618,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam transformer_engine::DType, InputsFillCase>> {}; -#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) 
\ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \ -} - -#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \ -} - TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { #ifndef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures @@ -629,35 +653,48 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { const bool colwise = block_size.first != 1; if (processing_method == ProcessingMethod::CAST_ACT) { // Forward activations - ACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } else { - DACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &dgelu; break; + case ActivationType::SiLU: OP = &dsilu; break; + case ActivationType::ReLU: OP = &drelu; break; + case ActivationType::QGeLU: OP = &dqgelu; break; + case ActivationType::SReLU: OP = &dsrelu; break; + } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || 
block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 96663e752..52180786d 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -24,108 +24,32 @@ using namespace test; namespace { -template -void scale_block(const IType* grad, +template +void compute_ref(const IType* grad, const IType* input, - OType* output, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t scale_idx_gate, - float& thread_amax, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { - - float block_amax = 0.0f; - float block_amax_gate = 0.0f; - const size_t stride = cols * 2; - - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - float gated_amax_act = 0; - float gated_amax_gate = 0; - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - gated_amax_act = abs(after_dsilu); - gated_amax_gate = abs(after_dgate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - gated_amax_act = abs(after_silu); - } - - if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } - if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * - Quantized_Limits::max_reciprocal()); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - float scale_reciprocal_gate = 1; - if constexpr (IS_DGATED) { - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate * - Quantized_Limits::max_reciprocal()); - scale_reciprocal_gate = exp2f_rcp(biased_exponent); - output_scales[scale_idx_gate] = biased_exponent; - } - - - // Quantize elements in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal); - output[i * stride + cols + j] = static_cast(after_dgate * - scale_reciprocal_gate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - output[i * cols + j] = static_cast(after_silu * scale_reciprocal); - } - - } - } - thread_amax = 
std::max(thread_amax, block_amax); - thread_amax = std::max(thread_amax, block_amax_gate); -} - -template -void compute_ref_x1(const IType* grad, - const IType* input, - OType* output, - fp8e8m0* output_scales, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) { - const size_t tile_size_Y = std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + OType* output_rowwise, + OType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + float& ref_amax, + const bool IS_DGATED, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise, + const bool is_rowwise, + const bool is_colwise) { + constexpr size_t tile_size_Y = 32; + constexpr size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; - float amax = 0; #pragma omp parallel reduction(max: amax) proc_bind(spread) { - float thread_amax = 0; + // Buffers to cache intermediate computations + std::vector cache_buffer_act(tile_size_Y * tile_size_X); + std::vector cache_buffer_gate(tile_size_Y * tile_size_X); + float thread_amax = 0.0f; #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { const size_t tile_Y = t / tiles_num_X; @@ -133,26 +57,124 @@ void compute_ref_x1(const IType* grad, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X; - const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X + - cols / block_size_X; - scale_block( - grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate, - thread_amax, i_min, i_max, j_min, j_max, cols); + const size_t stride = cols * 2; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(rows, tile_offset_Y + tile_size_Y); + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(cols, tile_offset_X + tile_size_X); + + // Compute and cache activations for the entire tile + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); + + if (IS_DGATED) { + const float x = silu_elt; + const float s = sigmoid(x); + const float act_x = x * s; + const float dact_x = x * s * (1 - s) + s; + + const float grad_elt = static_cast(grad[i * cols + j]); + float after_dsilu = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * 
grad_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_dsilu = static_cast(static_cast(after_dsilu)); + after_dgate = static_cast(static_cast(after_dgate)); + + cache_buffer_act[cached_idx] = after_dsilu; + cache_buffer_gate[cached_idx] = after_dgate; + thread_amax = std::max(thread_amax, std::abs(after_dsilu)); + thread_amax = std::max(thread_amax, std::abs(after_dgate)); + } else { + float after_silu = silu(silu_elt) * gate_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_silu = static_cast(static_cast(after_silu)); + + cache_buffer_act[cached_idx] = after_silu; + thread_amax = std::max(thread_amax, std::abs(after_silu)); + } + } + } + + if (is_rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const size_t scale_idx_act = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + const size_t scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; + output_scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t j = j_min; j < j_max; ++j) { + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_rowwise[i * stride + j] = static_cast(after_act); + output_rowwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_rowwise[i * cols + j] = static_cast(after_act); + } + } + } + } + + if (is_colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const size_t scale_idx_act = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + const size_t scale_idx_gate = scale_idx_act + cols; + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + output_scales_colwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t i = i_min; i < i_max; ++i) { + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); + 
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_colwise[i * stride + j] = static_cast(after_act); + output_colwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_colwise[i * cols + j] = static_cast(after_act); + } + } } } } @@ -163,26 +185,6 @@ void compute_ref_x1(const IType* grad, ref_amax = amax; } -template -void compute_ref_x2(const IType* grad, - const IType* input, - OType* output_rowwise, - OType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -190,12 +192,13 @@ void compute_ref_x2(const IType* grad, * OR * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -205,12 +208,6 @@ void performTest_x1(const size_t rows, const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); NVTE_CHECK(rowwise || colwise); - // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl; - // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl; - // std::cout << "blocks_Y: " << blocks_Y << std::endl; - // std::cout << "blocks_X: " << blocks_X << std::endl; - // std::cout << "scales_stride: " << scales_stride << std::endl; - Tensor grad("grad", std::vector{ rows, cols }, itype); Tensor input("input", std::vector{ rows, cols * 2 }, itype); @@ -236,12 +233,12 @@ void performTest_x1(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -252,46 +249,59 @@ void performTest_x1(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x1(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_scales.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); -#ifdef __HIP_PLATFORM_AMD__ - std::vector> mismatch_idx; - if (rowwise) { - compare_e8m0_scaling_factors("rowwise scales", output, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, true, mismatch_idx); - } else { - compare_e8m0_scaling_factors("colwise scales", output, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, false, mismatch_idx); - } - if (mismatch_idx.size()) { - adjust_ref(mismatch_idx, ref_output.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype); - } - - auto [atol, rtol] = 
getTolerances(otype); - compareResults("output", output, ref_output.get(), rowwise, atol, rtol); -#else // #ifdef __HIP_PLATFORM_AMD__ - auto [atol, rtol] = getTolerances(otype); - compareResults("output", output, ref_output.get(), rowwise, atol, rtol); + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride, + scales_stride, + rowwise, + colwise); + + std::vector mismatches_scales_indices; + size_t mismatches_scales = 0; + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; const uint8_t * const gpu_scales_ptr = rowwise ? output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales_indices, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales_indices, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } -#endif // #ifdef __HIP_PLATFORM_AMD__ + +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr, + ref_output_scales.get(), scales_stride, rows, output_cols, + rowwise, ref_output.get(), otype); + mismatches_scales = 0; +#endif + + const size_t mismatches_elts = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts); } /** @@ -301,12 +311,13 @@ void performTest_x1(const size_t rows, * AND * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x2(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -348,12 +359,12 @@ void performTest_x2(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -364,54 +375,65 @@ void performTest_x2(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x2(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output_rowwise.get(), - ref_output_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); -#ifdef __HIP_PLATFORM_AMD__ - std::vector> mismatch_idx_r; - compare_e8m0_scaling_factors("scales_rowwise", output, + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + 
ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise, + true, + true); + + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; + + std::vector mismatches_scales_indices_rowwise; + size_t mismatches_scales_rowwise = 0; + compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r); - - if (mismatch_idx_r.size()) { - adjust_ref(mismatch_idx_r, ref_output_colwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype); - } - - std::vector> mismatch_idx_c; - compare_e8m0_scaling_factors("scales_colwise", output, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_indices_rowwise, mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + std::vector mismatches_scales_indices_colwise; + size_t mismatches_scales_colwise = 0; + compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c); + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_indices_colwise, mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); - if (mismatch_idx_c.size()) { - adjust_ref(mismatch_idx_c, ref_output_rowwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype); - } +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, + output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), scales_stride_rowwise, rows, + output_cols, true, ref_output_rowwise.get(), otype); + mismatches_scales_rowwise = 0; + adjust_ref_for_e8m0_scale_error("scales_colwise", mismatches_scales_indices_colwise, + output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), scales_stride_colwise, rows, + output_cols, false, ref_output_colwise.get(), otype); + mismatches_scales_colwise = 0; +#endif + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; auto [atol, rtol] = getTolerances(otype); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); -#else // #ifdef __HIP_PLATFORM_AMD__ - auto [atol, rtol] = getTolerances(otype); - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - 
unpadded_blocks_X_colwise, scales_stride_colwise); -#endif // #ifdef __HIP_PLATFORM_AMD__ + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise); } std::vector> matrix_sizes = { @@ -422,8 +444,8 @@ std::vector> matrix_sizes = { {256, 256}, {993, 512}, {768, 1024}, - {65504, 128}, - {16384, 1632}, + {8192, 128}, + {577, 1632}, }; std::vector> block_sizes = { @@ -440,9 +462,9 @@ std::vector input_scenarios = { // InputsFillCase::maxNorm_to_inf }; -std::vector is_dgated_op = { - true, - false +std::vector is_bwd_op = { + false, + true }; } // namespace @@ -479,21 +501,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, if (block_size.first == 1 || block_size.second == 1) { - if (IS_DGATED) { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } else { - if (IS_DGATED) { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } ); ); @@ -508,7 +520,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), - ::testing::ValuesIn(is_dgated_op)), + ::testing::ValuesIn(is_bwd_op)), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + @@ -517,6 +529,6 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<2>(info.param)) + "X" + test::typeName(std::get<3>(info.param)) + "X" + test::caseName(std::get<4>(info.param)) + "X" + - (std::get<5>(info.param) ? "DGATED" : "GATED"); + (std::get<5>(info.param) ? 
"BWD" : "FWD"); return name; }); diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 1896cd329..f2e5c9499 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -29,15 +29,24 @@ namespace { template void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, - NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) { + const NormType norm_type, const bool use_cudnn, + const bool zero_centered_gamma_in_weight_dtype, const bool fused_bwd_add) { if (sizeof(InputType) < sizeof(OutputType)) { GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; - return; + } + + if (norm_type == LayerNorm && fused_bwd_add) { + GTEST_SKIP() << "Fused LN backward+add not currently supported"; + } + + if (fused_bwd_add && zero_centered_gamma_in_weight_dtype) { + GTEST_SKIP() << "zero_centered_gamma_in_weight_dtype not currently supported " + << "in fused norm backward+add"; } #ifndef __HIP_PLATFORM_AMD__ - if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) { - GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!"; + if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) { + GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!"; } #endif @@ -49,7 +58,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || (itype == DType::kFloat16 && otype == DType::kBFloat16)) { GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; - return; } Tensor input("input", std::vector{ N, H }, itype); @@ -59,6 +67,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, Tensor mu("mu", std::vector{ N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); Tensor dz("dz", std::vector{ N, H }, wtype); + Tensor bwd_add("bwd_add", std::vector{ N, H }, wtype); Tensor dx("dx", std::vector{ N, H }, itype); Tensor dgamma("dgamma", std::vector{ H }, wtype); Tensor dbeta("dbeta", std::vector{ H }, wtype); @@ -69,6 +78,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, fillUniform(&beta); setRandomScale(&z); fillUniform(&dz); + if (fused_bwd_add) { + fillUniform(&bwd_add); + } else { + fillCase(&bwd_add, zeros); + } std::unique_ptr ref_output = std::make_unique(N * H); std::unique_ptr ref_mu = std::make_unique(N); @@ -94,7 +108,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_enable_cudnn_norm_fwd(true); nvte_enable_cudnn_norm_bwd(true); - // Zero-centered gamma in weight dtype only supported by CuDNN backend currently if (zero_centered_gamma_in_weight_dtype) { nvte_enable_zero_centered_gamma_in_weight_dtype(true); @@ -135,15 +148,23 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, z.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), - workspace_bwd.data(), - prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), - workspace_bwd.data(), - prop.multiProcessorCount, zero_centered_gamma, 0); + if 
(fused_bwd_add) { + nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } else { + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, + zero_centered_gamma, 0); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, + zero_centered_gamma, 0); + } } #ifndef __HIP_PLATFORM_AMD__ @@ -179,6 +200,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, use_cudnn, zero_centered_gamma_in_weight_dtype); compute_ref_backward(norm_type, dz.rowwise_cpu_dptr(), + bwd_add.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), mu.rowwise_cpu_dptr(), rsigma.rowwise_cpu_dptr(), gamma.rowwise_cpu_dptr(), @@ -226,30 +248,40 @@ std::vector> test_cases = { } // namespace class NormTestSuite : public ::testing::TestWithParam, - bool, - bool>> {}; + NormType, + transformer_engine::DType, + transformer_engine::DType, + std::pair, + bool, + bool, + bool>> {}; TEST_P(NormTestSuite, TestNorm) { - using namespace transformer_engine; - using namespace test; + using namespace transformer_engine; + using namespace test; const bool use_cudnn = std::get<0>(GetParam()); const NormType norm_type = std::get<1>(GetParam()); - const DType input_type = std::get<2>(GetParam()); - const DType output_type = std::get<3>(GetParam()); - const auto size = std::get<4>(GetParam()); - const bool zero_centered_gamma = std::get<5>(GetParam()); - const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const auto size = std::get<4>(GetParam()); + const bool zero_centered_gamma = std::get<5>(GetParam()); + const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam()); + const bool fused_bwd_add = std::get<7>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest( + size.first, + size.second, + zero_centered_gamma, + norm_type, + use_cudnn, + cudnn_zero_centered_gamma_in_weight_dtype, + fused_bwd_add ); ); + ); } INSTANTIATE_TEST_SUITE_P( @@ -267,10 +299,11 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(test_cases), ::testing::Values(false, true), #ifdef __HIP_PLATFORM_AMD__ - ::testing::Values(false)), // HIP does not use cudnn_zero_centered_gamm_in_weight_dtype + ::testing::Values(false), // HIP does not use cudnn_zero_centered_gamm_in_weight_dtype #else - ::testing::Values(true, false)), + ::testing::Values(false, true), #endif + ::testing::Values(false, true)), [](const testing::TestParamInfo& info) { auto backend = 
std::get<0>(info.param) == false ? "Te" : "Cudnn"; std::string name = @@ -281,6 +314,7 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<4>(info.param).first) + "X" + std::to_string(std::get<4>(info.param).second) + "X" + std::to_string(std::get<5>(info.param)) + "X" + - std::to_string(std::get<6>(info.param)); + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param)); return name; }); diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index ddb9838b2..674e09c8e 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -135,7 +135,8 @@ void compute_ref_output(NormType norm_type, template -void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, +void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, + const OutputType *add, const InputType *data, const float *mu, const float *rsigma, const InputType *gamma, InputType *data_grad, @@ -174,7 +175,8 @@ void compute_ref_backward(const NormType norm_type, const OutputType *output_gra compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); const compute_t dz = static_cast(output_grad[i * H + j]); const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); + const compute_t a = static_cast(add[i * H + j]); + const compute_t dx = a + rsigma[i] * (dy - mdyy * y - mdy); data_grad[i * H + j] = static_cast(dx); } } diff --git a/tests/cpp/operator/test_swap_first_dims.cu b/tests/cpp/operator/test_swap_first_dims.cu new file mode 100644 index 000000000..4c2cf415f --- /dev/null +++ b/tests/cpp/operator/test_swap_first_dims.cu @@ -0,0 +1,112 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const Type *input, Type *output, + const std::vector &shape) { + const size_t dim0 = shape[0]; + const size_t dim1 = shape[1]; + size_t dim2 = 1; + for (size_t i = 2; i < shape.size(); ++i) { + dim2 *= shape[i]; + } + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + const size_t in_offset = i * dim1 * dim2 + j * dim2 + k; + const size_t out_offset = j * dim0 * dim2 + i * dim2 + k; + output[out_offset] = input[in_offset]; + } + } + } +} + +template +void performTest(const std::vector &in_shape) { + using namespace test; + + DType dtype = TypeInfo::dtype; + + // Tensor dimensions + std::vector out_shape = in_shape; + out_shape[0] = in_shape[1]; + out_shape[1] = in_shape[0]; + size_t numel = 1; + for (const auto& dim : in_shape) { + numel *= dim; + } + + // Transformer engine implementation + Tensor input("input", in_shape, dtype); + Tensor output("output", out_shape, dtype); + fillUniform(&input); + nvte_swap_first_dims(input.data(), output.data(), 0); + + // Reference implementation + std::unique_ptr ref_output = std::make_unique(numel); + compute_ref(input.rowwise_cpu_dptr(), ref_output.get(), in_shape); + + // Check for CUDA failure + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Check for exact numerics + compareResults("output", output, ref_output.get(), true, 0, 0); +} + +std::vector> test_cases = {{4, 64, 1280}, + {48, 8, 128, 16}, + {229, 173}, // Primes 50, 40 + {113, 71, 1, 1, 1, 29, 1, 1}}; // Primes 30, 20, 10 +} // namespace + +class SwapFirstDimsTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(SwapFirstDimsTestSuite, TestSwapFirstDims) { + using namespace transformer_engine; + using namespace test; + + const DType type = std::get<0>(GetParam()); + const auto shape = std::get<1>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, + performTest(shape); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwapFirstDimsTestSuite, + ::testing::Combine( + ::testing::ValuesIn(test::all_fp_types), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)); + for (const auto& dim : std::get<1>(info.param)) { + name += "X"; + name += std::to_string(dim); + } + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index a608f6ef2..9f926d07b 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
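The test_common.cu hunks that follow replace the old fail-on-first-mismatch comparison with one that counts mismatching elements and only fails once the count exceeds a limit derived as min(absolute limit, floor(N * relative limit)). The snippet below is a minimal standalone sketch of that counting rule only, under assumed names (buffers_match is hypothetical and not part of the test suite), not the actual compareResults implementation:

// Hedged sketch: tolerate up to min(abs_limit, floor(N * rel_limit)) mismatching
// elements before declaring two buffers different.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

bool buffers_match(const std::vector<float>& test, const std::vector<float>& ref,
                   double atol, double rtol,
                   double abs_mismatch_limit, double rel_mismatch_limit) {
  const size_t n = test.size();
  const size_t tolerable =
      static_cast<size_t>(std::min(abs_mismatch_limit, std::floor(n * rel_mismatch_limit)));
  size_t mismatches = 0;
  for (size_t i = 0; i < n; ++i) {
    const double t = test[i];
    const double r = ref[i];
    // Element mismatches if it violates both the absolute and the relative tolerance.
    const bool mismatch = std::fabs(t - r) > atol && (r == 0 || std::fabs((t - r) / r) > rtol);
    if (mismatch && ++mismatches > tolerable) {
      std::printf("element %zu: %f vs %f (mismatches=%zu > limit=%zu)\n",
                  i, t, r, mismatches, tolerable);
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<float> ref(1000, 1.0f), test(ref);
  test[17] = 1.5f;  // a single outlier
  // With a relative limit of 1e-2 (up to 10 tolerated mismatches here), this passes.
  std::printf("%s\n", buffers_match(test, ref, 1e-3, 1e-3, 16.0, 1e-2) ? "OK" : "FAIL");
  return 0;
}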
@@ -528,21 +528,19 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { void compareResults_sequential(const std::string &name, const Tensor &test, const void *ref, const bool rowwise, - double atol, double rtol, bool if_on_gpus) { + double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches_num = 0; + int first_mismatch_idx = -1; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); for (size_t i = 0; i < N; ++i) { -#ifndef __HIP_PLATFORM_AMD__ double t = static_cast(test_data[i]); double r = static_cast(ref_data[i]); -#else - double t = static_cast(static_cast(test_data[i])); - double r = static_cast(static_cast(ref_data[i])); -#endif bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); /* For Float32 the floating point comparison is enough to error out */ bool assertion = mismatch && test.dtype() == DType::kFloat32; @@ -557,85 +555,102 @@ void compareResults_sequential(const std::string &name, const Tensor &test, assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); } std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl + << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape)) + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } } ); } template static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, - const size_t N, const double atol, const double rtol) { + const size_t N, const double atol, const double rtol, + size_t& mismatches) { int first_mismatch_idx = N; - bool is_mismatch_found = false; - #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ - reduction(min: first_mismatch_idx) proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - if (is_mismatch_found) { // early escape of the omp thread - continue; - } + #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread) + { + size_t thread_mismatches = 0; + #pragma omp for schedule(static) + for (size_t i = 0; i < N; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); -#ifndef __HIP_PLATFORM_AMD__ - double t = static_cast(test_data[i]); - double r = static_cast(ref_data[i]); -#else - double t = static_cast(static_cast(test_data[i])); - double r = static_cast(static_cast(ref_data[i])); -#endif - - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = mismatch && (data_type == DType::kFloat32); - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); - } - if (assertion && i < first_mismatch_idx) { - first_mismatch_idx = i; - is_mismatch_found = true; + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + if (i < first_mismatch_idx) { + first_mismatch_idx = i; + } + thread_mismatches++; + } } + mismatches += thread_mismatches; } return first_mismatch_idx; } void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? 
test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches = 0; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); - const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol); - if (i != N) { + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches); + if ((i != N) && (mismatches > tolerable_mismatches_limit)) { const double t = static_cast(test_data[i]); const double r = static_cast(ref_data[i]); std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(true) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + + GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) + << " (" << std::to_string(i) << "): " << t << " vs " << r; } ); } void compareResults(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { constexpr bool sequential = false; if constexpr (sequential) { - compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } else { - compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } } @@ -672,93 +687,84 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride) + const size_t row_blocks, const size_t col_blocks, const size_t stride, + std::vector &mismatch_indices, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + const size_t N = row_blocks * col_blocks; + const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, + std::floor(N * rel_tolerable_mismatches_limit)); + mismatches_num = 0; + for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[idx]) << " vs " - << static_cast(ref[idx]) << " at index " << idx; - } - } -} + const int test_val = static_cast(test[idx]); + const int ref_val = static_cast(ref[idx]); + const int abs_delta = std::abs(test_val - ref_val); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t N) -{ - for (int i = 0; i < N; i++) { - ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[i]) << " vs " - << static_cast(ref[i]) << " at index " << i; - } -} - -#ifdef __HIP_PLATFORM_AMD__ -void 
compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - double tol, bool rowwise, std::vector> &mismatch_idx) { - const uint8_t *const test = rowwise ? output.rowwise_cpu_scale_inv_ptr() - : output.columnwise_cpu_scale_inv_ptr(); - - const double scale_tol = std::max(1., row_blocks * col_blocks * tol); - - for (int i = 0; i < row_blocks; i++) { - for (int j = 0; j < col_blocks; j++) { - const int idx = i * stride + j; - if (test[idx] != ref[idx]) { - int t_scale = static_cast(test[idx]); - int r_scale = static_cast(ref[idx]); - if (std::abs(t_scale - r_scale) == 1) { - mismatch_idx.emplace_back(i, j, r_scale-t_scale); - } else { - GTEST_FAIL() << "Error in " << name << std::endl - << "Mismatch: " << t_scale << " vs " - << r_scale << " at index " << idx; + if (abs_delta > atol) { + mismatches_num++; + mismatch_indices.push_back(idx); + } + if (mismatches_num > tolerable_mismatches_limit) { + std::cout << "Error in " << name << std::endl; + for (const int index : mismatch_indices) { + std::cout << "Mismatch at (" << index << "):" + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; } + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "."; } } } - const size_t scale_mismatches = mismatch_idx.size(); - - ASSERT_FALSE(scale_mismatches > scale_tol) - << "Error in " << name << std::endl << std::setprecision(4) - << "Total scale mismatches: " << scale_mismatches << " (" << 100.*(double)scale_mismatches/(double)(row_blocks*col_blocks) - << "%) Exceeds tolerance of " << scale_tol << " (" << 100.*tol << "%) mismatches"; - - if (scale_mismatches) { - std::cout << "\x1b[33mWARNING:\x1b[0m " << scale_mismatches - << " scale mismatches were found. This does not imply an accuracy issue." << std::endl; - } } -void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks, - const size_t col_blocks, const size_t rows, const size_t cols, DType otype) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( otype, T, - T *ref_data = reinterpret_cast(ref); + +#ifdef __HIP_PLATFORM_AMD__ +void adjust_ref_for_e8m0_scale_error(const std::string &name, + const std::vector &mismatch_idx, + const uint8_t *test_scale, const uint8_t *ref_scale, + const size_t scale_stride, const size_t rows, + const size_t cols, bool rowwise, void *ref_ptr, DType otype) { + if (mismatch_idx.size() == 0) { + return; + } + const size_t col_blocks_size = rowwise ? 32 : 1; + const size_t row_blocks_size = rowwise ? 1 : 32; + GTEST_LOG_(INFO) << "Adjusting reference data for " << mismatch_idx.size() + << " scale mismatches in tensor " << name << " " + << (rowwise ? "rowwise" : "colwise") << " direction." 
<< std::endl; + for (const auto scale_idx : mismatch_idx) { + const int scale_diff = ref_scale[scale_idx] - test_scale[scale_idx]; double scale_val; - const size_t col_blocks_size = cols / col_blocks; - const size_t row_blocks_size = rows / row_blocks; - for (const auto &[i, j, scale_diff] : mismatch_idx) { - if (scale_diff == 1) { - scale_val = 2.; - } else if (scale_diff == -1) { - scale_val = .5; - } else { // Shouldn't ever reach this - GTEST_FAIL() << "Error in adjust_ref, |scale_diff| > 1"; - } - size_t ii_min = i * row_blocks_size; - const size_t ii_max = std::min(ii_min + row_blocks_size, rows); - for (; ii_min < ii_max; ii_min++) { - size_t jj_min = j * col_blocks_size; - const size_t jj_max = std::min(jj_min + col_blocks_size, cols); - for (; jj_min < jj_max; jj_min++) { - const size_t data_idx = ii_min * cols + jj_min; + if (scale_diff == 1) { + scale_val = 2.; + } else if (scale_diff == -1) { + scale_val = .5; + } else { + GTEST_FAIL() << "Error in " << name << ": mismatch " << test_scale[scale_idx] << " vs " + << ref_scale[scale_idx] << " at index " << scale_idx; + } + const int i = scale_idx / scale_stride; + const int j = scale_idx % scale_stride; + size_t ii_min = i * row_blocks_size; + const size_t ii_max = std::min(ii_min + row_blocks_size, rows); + for (; ii_min < ii_max; ii_min++) { + size_t jj_min = j * col_blocks_size; + const size_t jj_max = std::min(jj_min + col_blocks_size, cols); + for (; jj_min < jj_max; jj_min++) { + const size_t data_idx = ii_min * cols + jj_min; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(otype, T, { + T *ref_data = reinterpret_cast(ref_ptr); ref_data[data_idx] = static_cast(static_cast(ref_data[data_idx]) * scale_val); - } + }); // NOLINT(*) } } - ); // NOLINT(*) + } } #endif // #ifdef __HIP_PLATFORM_AMD__ @@ -897,9 +903,18 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) { } } +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); +#if FP4_TYPE_SUPPORTED +template void fillCase(Tensor *t, const InputsFillCase fill_case); +#endif void setRandomScale(Tensor *t) { std::uniform_real_distribution<> dis(-2.0, 1.0); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index bfb46f8a0..3c0a387c6 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -437,7 +437,12 @@ inline fp8e8m0 float_to_e8m0(float val) { } inline float exp2f_rcp(fp8e8m0 biased_exp) { - return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp));
+  if (biased_exp == 0) {
+    return 1.0f;
+  }
+  int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS;  // 127 - (biased_exp - 127)
+  float fp32_val = *reinterpret_cast<float*>(&int_val);
+  return fp32_val;
 }
 
 inline float identity(const float x) { return x; }
@@ -469,22 +474,25 @@ size_t last_dimension(const std::vector &shape);
 bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
 
 void compareResults(const std::string &name, const Tensor &test, const void *ref,
-                    bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true);
+                    bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
+                    const size_t tolerable_mismatches_limit = 0);
 void compareResults(const std::string &name, const float test, const float ref,
                     double atol = 1e-5, double rtol = 1e-8);
 void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                     size_t N, float mismatch_rate_tol = 0.);
 void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                  const size_t row_blocks, const size_t col_blocks, const size_t stride);
-void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                  const size_t N);
-#ifdef USE_ROCM
-void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref,
-                                  const size_t row_blocks, const size_t col_blocks, const size_t stride,
-                                  double tol, bool rowwise, std::vector> &mismatch_idx);
+                                  const size_t row_blocks, const size_t col_blocks, const size_t stride,
+                                  std::vector<int> &mismatch_indices, size_t& mismatches_num,
+                                  const size_t scale_diff_abs_tolerance = 0,
+                                  const double abs_tolerable_mismatches_limit = 0,
+                                  const double rel_tolerable_mismatches_limit = 0);
 
-void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks,
-                const size_t col_blocks, const size_t rows, const size_t cols, DType otype);
+#ifdef USE_ROCM
+void adjust_ref_for_e8m0_scale_error(const std::string &name,
+                                     const std::vector<int> &mismatch_idx,
+                                     const uint8_t *test_scale, const uint8_t *ref_scale,
+                                     const size_t scale_stride, const size_t rows,
+                                     const size_t cols, bool rowwise, void *ref_ptr, DType otype);
 #endif
 
 std::array get_scale_tensor_dims(const size_t rows, const size_t cols,
diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt
new file mode 100644
index 000000000..ed3ddeb88
--- /dev/null
+++ b/tests/cpp_distributed/CMakeLists.txt
@@ -0,0 +1,57 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
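Note: the compareResults/compare_e8m0_scaling_factors rework above replaces the fail-on-first-mismatch assertions with a counted limit, so a comparison only fails once the number of out-of-tolerance elements exceeds tolerable_mismatches_limit. A rough numpy sketch of that counting logic follows; the function and variable names are illustrative, not the harness code itself.

# Illustrative sketch of the mismatch-counting tolerance check (not TE test code).
import numpy as np

def count_mismatches(test, ref, atol, rtol):
    test = np.asarray(test, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    abs_err = np.abs(test - ref)
    rel_err = abs_err / np.maximum(np.abs(ref), np.finfo(np.float64).tiny)
    # Mirrors the C++ check: an element only counts as a mismatch if it violates
    # the absolute tolerance AND (the reference is zero OR it violates the
    # relative tolerance).
    bad = (abs_err > atol) & ((ref == 0) | (rel_err > rtol))
    idx = np.flatnonzero(bad)
    first = int(idx[0]) if idx.size else None
    return first, int(idx.size)

tolerable_mismatches_limit = 1
first, n_bad = count_mismatches([1.0, 1.5, 2.0], [1.0, 1.0, 2.0], atol=1e-5, rtol=1e-8)
assert n_bad <= tolerable_mismatches_limit, (first, n_bad)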
+ +cmake_minimum_required(VERSION 3.18) + +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + endif() +endif() + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +project(transformer_engine_distributed_tests LANGUAGES CUDA CXX) + +add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) + +include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) + +if(NOT DEFINED TE_LIB_PATH) + execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'" + OUTPUT_VARIABLE TE_LIB_FILE + OUTPUT_STRIP_TRAILING_WHITESPACE) + get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) +endif() + +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) + +message(STATUS "Found transformer_engine library: ${TE_LIB}") +include_directories(../../transformer_engine/common/include) +include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) +include_directories(${CMAKE_SOURCE_DIR}) + +find_package(CUDAToolkit REQUIRED) + +add_executable(test_comm_gemm + test_comm_gemm.cu + ../cpp/test_common.cu) + +find_package(OpenMP REQUIRED) +find_package(MPI REQUIRED) +find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) +target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) +target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) + +include(GoogleTest) +gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp_distributed/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu new file mode 100644 index 000000000..8355d5f96 --- /dev/null +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -0,0 +1,441 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
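For reference, the new exp2f_rcp in tests/cpp/test_common.h above computes 1 / 2^(biased_exp - 127) by building the float32 bit pattern directly: with a zero mantissa, an exponent field of (254 - biased_exp) encodes exactly 2^(127 - biased_exp), so no exp2f call is needed. A quick numpy check of that identity, valid for biased exponents whose reciprocal stays in the normal float32 range (the zero case mirrors the special-cased return of 1.0f):

# Quick check of the bit trick behind exp2f_rcp; not part of the test suite.
import numpy as np

FP32_MANTISSA_BITS = 23

def exp2_rcp(biased_exp: int) -> np.float32:
    if biased_exp == 0:
        return np.float32(1.0)
    bits = np.uint32((254 - biased_exp) << FP32_MANTISSA_BITS)
    return bits.view(np.float32)

for e in range(1, 254):  # 254 itself would need a subnormal reciprocal
    assert exp2_rcp(e) == np.float32(2.0) ** (127 - e), e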
+ ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../cpp/test_common.h" +#include "common.h" + +using transformer_engine::DType; +using transformer_engine::TypeInfo; + +#define CHECK_MPI(expr) \ + do { \ + int err = (expr); \ + if (err != MPI_SUCCESS) { \ + char err_str[MPI_MAX_ERROR_STRING + 1]{}; \ + int _len{}; \ + MPI_Error_string(err, err_str, &_len); \ + EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \ + } \ + } while (false) + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t err = (expr); \ + if (err != ncclSuccess) { \ + EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \ + } \ + } while (false) + +#define CHECK_CU(expr) \ + do { \ + CUresult err = (expr); \ + if (err != CUDA_SUCCESS) { \ + const char* str{}; \ + CUresult e_str = cuGetErrorString(err, &str); \ + if (e_str != CUDA_SUCCESS) str = "(unknown)"; \ + EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \ + } \ + } while (false) + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + CHECK_MPI(MPI_Init(&argc, &argv)); + auto ret = RUN_ALL_TESTS(); + CHECK_MPI(MPI_Finalize()); + return ret; +} + +bool IsMulticastSupported(int device_id) { + int supported = 0; + CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id)); + return supported; +} + +template +std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld) { + std::vector ret(msize * nsize); + size_t dst = 0; + for (size_t j = nstart; j < nstart + nsize; ++j) { + for (size_t i = mstart; i < mstart + msize; ++i) { + ret[dst++] = data[j * ld + i]; + } + } + return ret; +} + +template +test::Tensor Make(size_t m, size_t n, float scale) { + test::Tensor ret("", std::vector{n, m}, TypeInfo::dtype); + ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + return ret; +} + +template +test::Tensor MakeFromData(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld, float scale) { + test::Tensor ret("", std::vector{nsize, msize}, TypeInfo::dtype); + ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld); + NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0], + cudaMemcpyDefault)); + return ret; +} + +template +float GetScale(float amax) { + if constexpr (sizeof(T) > 1) return 1.0; + return static_cast(static_cast(std::numeric_limits::max())) / amax; +} + +struct Params { + DType a_type; + DType b_type; + DType d_type; + bool transa; + bool transb; + size_t m; + size_t n; + size_t k; + float tol; +}; + +class CommGemmFixure : public ::testing::TestWithParam { + protected: + CommGemmFixure() { + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_)); + NVTE_CHECK_CUDA(cudaSetDevice(rank_)); + ncclUniqueId id{}; + if (rank_ == 0) CHECK_NCCL(ncclGetUniqueId(&id)); + CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + CHECK_NCCL(ncclCommInitRank(&comm_, nranks_, id, rank_)); + ctx_ = nvte_comm_gemm_ctx_create(comm_, nranks_, rank_); + } + ~CommGemmFixure() { + nvte_comm_gemm_ctx_destroy(ctx_); + ncclCommDestroy(comm_); + } + + struct PatternDims { + int64_t a_rows_start; + int64_t a_rows_num; + int64_t a_cols_start; + int64_t 
a_cols_num; + int64_t b_rows_start; + int64_t b_rows_num; + int64_t b_cols_start; + int64_t b_cols_num; + int64_t d_rows_start; + int64_t d_rows_num; + int64_t d_cols_start; + int64_t d_cols_num; + }; + + virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0; + + virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) = 0; + + template + void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { + cudaStream_t stream{}; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + constexpr float MAX_IN = 1.0; + std::mt19937 rng(12); + std::uniform_real_distribution dist(0.0, MAX_IN); + + float a_scale = GetScale(MAX_IN); + float b_scale = GetScale(MAX_IN); + float d_scale = GetScale(MAX_IN * MAX_IN * k); + float bias_scale = GetScale(MAX_IN); + + std::vector adata(m * k); + std::generate(adata.begin(), adata.end(), + [&rng, &dist, a_scale] { return static_cast(dist(rng) * a_scale); }); + std::vector bdata(k * n); + std::generate(bdata.begin(), bdata.end(), + [&rng, &dist, b_scale] { return static_cast(dist(rng) * b_scale); }); + std::vector biasdata(m * n); + std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] { + return static_cast(dist(rng) * bias_scale); + }); + + auto ga = transa ? MakeFromData(adata, 0, 0, k, m, k, a_scale) + : MakeFromData(adata, 0, 0, m, k, m, a_scale); + auto gb = transb ? MakeFromData(bdata, 0, 0, n, k, n, b_scale) + : MakeFromData(bdata, 0, 0, k, n, k, b_scale); + auto gbias = MakeFromData(biasdata, 0, 0, m, n, m, bias_scale); + auto gd = Make(m, n, d_scale); + auto gaux = Make(m, n, d_scale); + + auto dims = DistributeTensors(m, n, k); + auto a = transa ? MakeFromData(adata, dims.a_rows_start, dims.a_cols_start, + dims.a_rows_num, dims.a_cols_num, k, a_scale) + : MakeFromData(adata, dims.a_cols_start, dims.a_rows_start, + dims.a_cols_num, dims.a_rows_num, m, a_scale); + auto b = transb ? 
MakeFromData(bdata, dims.b_cols_start, dims.b_rows_start, + dims.b_cols_num, dims.b_rows_num, n, b_scale) + : MakeFromData(bdata, dims.b_rows_start, dims.b_cols_start, + dims.b_rows_num, dims.b_cols_num, k, b_scale); + auto bias = MakeFromData(biasdata, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m, bias_scale); + auto d = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + auto aux = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + + bool grad = false; + bool accumulate = false; + CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad, + accumulate, 0 /*comm_sm_count*/, stream); + auto workspace = Make(1, 32 << 20, 1.0); + nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb, + grad, workspace.data(), accumulate, false /* use_split_accumulator */, + 0 /* math_sm_count */, stream); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + std::vector out(dims.d_rows_num * dims.d_cols_num); + NVTE_CHECK_CUDA( + cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault)); + std::vector out_golden_global(m * n); + NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(), + out_golden_global.size() * sizeof out_golden_global[0], + cudaMemcpyDefault)); + + auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m); + NVTE_CHECK(out.size() == out_golden.size()); + for (size_t i = 0; i < out.size(); ++i) { + EXPECT_NEAR(static_cast(out[i]), static_cast(out_golden[i]), tol * k); + } + } + + NVTECommGemmCtx* ctx_{}; + int nranks_{}; + int rank_{}; + ncclComm_t comm_{}; +}; + +struct AgGemm : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m); + auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t a_cols_start{}; + int64_t b_cols_start{}; + MPI_Exscan(&a_cols_num, &a_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&b_cols_num, &b_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = 0, + .a_rows_num = k, + .a_cols_start = a_cols_start, + .a_cols_num = a_cols_num, + .b_rows_start = 0, + .b_rows_num = k, + .b_cols_start = b_cols_start, + .b_cols_num = b_cols_num, + .d_rows_start = a_cols_start, + .d_rows_num = a_cols_num, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_all_gather_gemm(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmRs : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t rows_start{}; + int64_t d_cols_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&d_cols_num, &d_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + 
.b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = d_cols_start, + .d_cols_num = d_cols_num, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_gemm_reduce_scatter(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmAr : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + + int64_t rows_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + .b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } + + void SetUp() override { + if (!IsMulticastSupported(rank_)) + GTEST_SKIP() << "Multicast is not supported on device " << rank_; + } +}; + +TEST_P(AgGemm, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmRs, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmAr, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +std::string ParamSuffix(const testing::TestParamInfo& info) { + const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param; + std::ostringstream ss; + ss << static_cast(a_type) << "_" << static_cast(b_type) << "_" + << static_cast(d_type) << "_" << (transa ? "T" : "N") << (transb ? 
"T" : "N") << "_" << m + << "x" << n << "x" << k; + return ss.str(); +} + +INSTANTIATE_TEST_SUITE_P(AgGemm, AgGemm, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P( + GemmAr, GemmAr, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}), + &ParamSuffix); diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 663a95418..cb5676d51 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -5,6 +5,8 @@ import os import jax import pytest +from collections import defaultdict +import time import transformer_engine.jax @@ -32,3 +34,54 @@ def enable_fused_attn_after_hopper(): yield if "NVTE_FUSED_ATTN" in os.environ: del os.environ["NVTE_FUSED_ATTN"] + + +class TestTimingPlugin: + """ + Plugin to measure test execution time. Enable test timing by setting NVTE_JAX_TEST_TIMING=1 + in the environment. 
+ """ + + def __init__(self): + self.test_timings = defaultdict(list) + + @pytest.hookimpl(tryfirst=True) + def pytest_runtest_setup(self, item): + item._timing_start = time.time() + + @pytest.hookimpl(trylast=True) + def pytest_runtest_teardown(self, item, nextitem): + if hasattr(item, "_timing_start"): + duration = time.time() - item._timing_start + + # Extract base function name without parameters + test_name = item.name + if "[" in test_name: + base_name = test_name.split("[")[0] + else: + base_name = test_name + + self.test_timings[base_name].append(duration) + + def pytest_sessionfinish(self, session, exitstatus): + print("\n" + "=" * 80) + print("TEST RUNTIME SUMMARY (grouped by function)") + print("=" * 80) + + total_overall = 0 + for test_name, durations in sorted(self.test_timings.items()): + total_time = sum(durations) + count = len(durations) + avg_time = total_time / count if count > 0 else 0 + total_overall += total_time + + print(f"{test_name:<60} | {count:3}x | {total_time:7.2f}s | avg: {avg_time:6.2f}s") + + print("=" * 80) + print(f"{'TOTAL RUNTIME':<60} | {'':>3} | {total_overall:7.2f}s |") + print("=" * 80) + + +def pytest_configure(config): + if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1": + config.pluginmanager.register(TestTimingPlugin(), "test_timing") diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 4caa0f027..3f3b5db84 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -25,7 +25,7 @@ def generate_configs(): pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") ) configs.append( - pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2") + pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2") ) if is_devices_enough(4): @@ -33,8 +33,8 @@ def generate_configs(): pytest.param( 4, (2, 2), - ("dp", "tp"), - MeshResource(dp_resource="dp", tp_resource="tp"), + ("dp", "tpsp"), + MeshResource(dp_resource="dp", tpsp_resource="tpsp"), id=f"n4_dp2_tp2", ) ) @@ -42,20 +42,28 @@ def generate_configs(): return configs -def generate_context_parallel_configs(): - configs = [] - mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") - axes = ("dp", "cp", "tp") +def generate_context_parallel_configs_for_attn(): + """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only""" + configsL1 = [] + configsL2 = [] + mr = MeshResource(dp_resource="dp", cp_resource="cp", tpsp_resource="tpsp") + axes = ("dp", "cp", "tpsp") DP_sizes = (1, 2) CP_sizes = (1, 2, 4, 8) TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp if is_devices_enough(ndev): - configs.append( - pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") - ) - + # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) + if cp != 1: + configsL1.append( + pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") + ) + else: + configsL2.append( + pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") + ) + configs = {"L0": [], "L1": configsL1, "L2": configsL2} return configs diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh new file mode 100644 index 000000000..fcb066de7 --- /dev/null +++ b/tests/jax/multi_process_launch.sh @@ -0,0 +1,23 @@ +# Copyright (c) 2022-2025, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +#!/bin/bash + +SCRIPT_NAME="${SCRIPT_NAME:-test.py}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_command_buffer=''" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +NUM_RUNS=$(nvidia-smi -L | wc -l) +for ((i=1; i /dev/null 2>&1 & +done + +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS + +wait diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20a8037eb..6ec3c27a4 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -34,6 +34,7 @@ from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.cpp_extensions.misc import is_hip_extension from transformer_engine.jax.quantize import ( + NoScaleTensor, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -86,8 +87,14 @@ def is_shape_supported_by_mxfp8(input_shape): return False -def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): +def assert_bitwise_scaled_tensors( + a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True +): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): + if not precise_comparison: + assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype) + return + assert a.scaling_mode == b.scaling_mode assert a.scale_inv.dtype == b.scale_inv.dtype if a.scaling_mode.is_tensor_scaling(): @@ -102,8 +109,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): assert_allclose(a.data, b.data) elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x): - assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor) - assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor) + assert_bitwise_scaled_tensors( + a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison + ) + assert_bitwise_scaled_tensors( + a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison + ) else: pytest.fail("Unsupported input types") @@ -180,7 +191,7 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type) + return _jax_act_lu(x, activation_type).data def value_n_grad_ref_func(self, x, activation_type): jitted_reference = jit( @@ -335,8 +346,8 @@ def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantize ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - # if isinstance(ln_out, ScaledTensor): - # ln_out = ln_out.dequantize() + # This is a no-op for non-quantized data + ln_out = ln_out.dequantize() return ln_out key = jax.random.PRNGKey(0) @@ -462,14 +473,23 @@ def _test_norm_forward( x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_mu, ref_rsigma = _jax_layernorm( - x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) else: output, rsigma = tex.rmsnorm_fwd( x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_rsigma = _jax_rmsnorm( - x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) ref_mu = None @@ -489,24 +509,7 @@ def _test_norm_forward( # if the input dtype is not float32 precise_comparison = False - if 
precise_comparison: - assert_bitwise_scaled_tensors(output, ref_out) - else: - if isinstance(ref_out, ScaledTensor1x): - assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype) - elif isinstance(ref_out, ScaledTensor2x): - assert_allclose( - output.rowwise_tensor.dequantize(), - ref_out.rowwise_tensor.dequantize(), - dtype=out_dtype, - ) - assert_allclose( - output.colwise_tensor.dequantize(), - ref_out.colwise_tensor.dequantize(), - dtype=out_dtype, - ) - else: - pytest.fail("Unsupported output type") + assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison) assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype) if norm_type == "layernorm": @@ -688,10 +691,6 @@ def test_grouped_qdq( n_groups=n_groups, ) - # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks - # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to - # disable cudaGraph, then use the following jitted function - scaled_tensor = tex.grouped_quantize( x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer ) @@ -776,12 +775,26 @@ def _test_quantize_dact_dbias( )(dz, x) if is_casted_output: - assert_bitwise_scaled_tensors(te_output, jax_output) + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation + precise_comparison = not ( + in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling() + ) + assert_bitwise_scaled_tensors( + te_output, jax_output, precise_comparison=precise_comparison + ) else: - assert_allclose(te_output, jax_output) + assert isinstance(te_output, NoScaleTensor) + assert isinstance(jax_output, NoScaleTensor) + assert_allclose(te_output.data, jax_output.data) if is_dbias: - assert_allclose(te_dbias, jax_dbias) + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. 
+ precise_comparison = not ( + in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling() + ) + assert_allclose( + te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype + ) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @@ -866,15 +879,6 @@ def test_quantize_dact_dbias_mxfp8_scaling( ] -def _use_jax_fp8_gemm(enabled=False): - import os - - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": @@ -1039,8 +1043,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - if isinstance(ln_out, ScaledTensor): - ln_out = ln_out.dequantize() + ln_out = ln_out.dequantize() return ln_out @@ -1196,7 +1199,7 @@ def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = _jax_act_lu(linear_1_out, activation_type) + x = _jax_act_lu(linear_1_out, activation_type).data linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape @@ -1327,16 +1330,14 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks - # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to - # disable cudaGraph, then use the following jitted function - # jitting grouped_gemm - # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - # lhs, rhs, group_sizes, contracting_dims, - # ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs, + rhs, + group_sizes, + contracting_dims, + ) - prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @@ -1365,12 +1366,7 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - # jitting grouped_gemm - # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))( - # lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set - # ) - - prim_out = tex.grouped_gemm( + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set ) @@ -1406,9 +1402,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) # jitting the grouped_dense - # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), - # static_argnums=(4,)) - value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = jit( + value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,) + ) 
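assert_bitwise_scaled_tensors now takes a precise_comparison flag: when TE and the JAX reference quantize through different intermediate precisions, the test falls back to comparing dequantized values instead of demanding identical bits. A hedged numpy sketch of that distinction, with illustrative names standing in for TE's ScaledTensor API:

# Sketch only: quantized payloads plus scales instead of real ScaledTensor objects.
import numpy as np

def compare_quantized(q_a, scale_a, q_b, scale_b, precise_comparison=True, rtol=1e-2):
    if precise_comparison:
        # Bit-for-bit: the quantized data and the scales must match exactly.
        assert np.array_equal(q_a, q_b) and np.array_equal(scale_a, scale_b)
    else:
        # Loose: only the dequantized values need to agree within tolerance,
        # which absorbs benign rounding differences between implementations.
        np.testing.assert_allclose(q_a * scale_a, q_b * scale_b, rtol=rtol)

# Passes in loose mode, would fail the bitwise path:
compare_quantized(np.array([10], dtype=np.int8), 0.5,
                  np.array([5], dtype=np.int8), 1.0,
                  precise_comparison=False)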
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( x, kernel, bias, group_sizes, contracting_dims @@ -1447,9 +1443,9 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) # jitting the grouped_dense - # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), - # static_argnums=(4,)) - value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = jit( + value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,) + ) ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( x, diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 5a824e8c6..28c1098e7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -11,12 +11,13 @@ from jax import random from distributed_test_base import ( generate_configs, - generate_context_parallel_configs, + generate_context_parallel_configs_for_attn, generate_collectives_count, ) from packaging import version from transformer_engine.jax.cpp_extensions.misc import is_hip_extension from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat +from utils import pytest_parametrize_wrapper from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, AttnBiasType, @@ -32,6 +33,12 @@ DTYPES = [jnp.bfloat16] +DISTRIBUTED_SELF_ATTN_DATA_SHAPES = { + "L0": [()], + "L1": [(32, 1024, 16, 128)], + "L2": [(32, 512, 12, 64)], +} + class TestDistributedSelfAttn: @@ -42,8 +49,8 @@ def generate_collectives_count_ref( _, seqlen, heads, _ = shape is_dp_enabled = mesh_resource.dp_resource is not None tp_size = 1 - if mesh_resource.tp_resource is not None: - idx = mesh_axes.index(mesh_resource.tp_resource) + if mesh_resource.tpsp_resource is not None: + idx = mesh_axes.index(mesh_resource.tpsp_resource) tp_size = mesh_shape[idx] all_reduce_loss_bytes = 4 # 1 * FP32 @@ -68,7 +75,6 @@ def impl_test_self_attn( jax.config.update("jax_use_shardy_partitioner", use_shardy) dropout_prob = 0.0 is_training = True - batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( @@ -124,13 +130,7 @@ def impl_test_self_attn( runner.test_backward() @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "data_shape", - [ - pytest.param((32, 512, 12, 64), id="32-512-12-64"), - pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), - ], - ) + @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize( "attn_bias_type, bias_shape", [ @@ -199,6 +199,13 @@ def test_self_attn_shardy( ) +DISTRIBUTED_CROSS_ATTN_DATA_SHAPES = { + "L0": [()], + "L1": [[32, 512, 16, 64]], + "L2": [[32, 128, 12, 64]], +} + + class TestDistributedCrossAttn: def generate_collectives_count_ref(self): @@ -207,7 +214,7 @@ def generate_collectives_count_ref(self): return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) + @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_CROSS_ATTN_DATA_SHAPES) @pytest.mark.parametrize( "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] ) @@ -407,8 +414,9 @@ def 
check_has_backend_for_mask(mask_type): runner.test_backward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @@ -444,8 +452,9 @@ def test_context_parallel_allgather_attn_shardy( use_shardy=True, ) - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("kv_groups", [1, 8]) @@ -486,8 +495,9 @@ def test_context_parallel_allgather_attn( use_shardy=False, ) - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("kv_groups", [1, 8]) @@ -551,8 +561,9 @@ def test_context_parallel_ring_attn( ) @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @@ -589,16 +600,16 @@ def test_context_parallel_ring_attn_shardy( ) +REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = { + "L0": [[]], + "L1": [[3, 32, 8, 64]], + "L2": [[4, 32, 12, 32], [1, 16, 1, 1]], +} + + class TestReorderCausalLoadBalancing: @pytest.mark.parametrize("cp_size", [2, 4, 8]) - @pytest.mark.parametrize( - "shape", - [ - pytest.param([1, 16, 1, 1], id="1-16-1-1"), - pytest.param([4, 32, 12, 32], id="4-32-12-32"), - pytest.param([3, 32, 8, 64], id="3-32-8-64"), - ], - ) + @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) @pytest.mark.parametrize( "reorder_strategy", diff --git a/tests/jax/test_distributed_helper.py b/tests/jax/test_distributed_helper.py new file mode 100644 index 000000000..e74e9aa6f --- /dev/null +++ b/tests/jax/test_distributed_helper.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
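The data-shape tables above are now dictionaries keyed by test level (L0/L1/L2) and consumed through pytest_parametrize_wrapper. The wrapper's implementation is not part of this diff, so the following is only a guess at the selection pattern; the NVTE_TEST_LEVEL variable and the helper name are hypothetical placeholders.

# Hypothetical sketch of flattening a level-keyed case table into pytest params.
import os
import pytest

CASES = {
    "L0": [()],
    "L1": [(32, 1024, 16, 128)],
    "L2": [(32, 512, 12, 64)],
}

def params_for_level(cases, level=None):
    level = level or os.getenv("NVTE_TEST_LEVEL", "L0")
    # Skip empty placeholders such as the bare () used for L0 above.
    return [pytest.param(shape, id="x".join(map(str, shape)))
            for shape in cases.get(level, []) if shape]

@pytest.mark.parametrize("data_shape", params_for_level(CASES, level="L1"))
def test_shape_is_4d(data_shape):
    assert len(data_shape) == 4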
+ +import unittest + +import jax +import numpy as np + +from utils import pytest_parametrize_wrapper, is_devices_enough +from transformer_engine.jax.sharding import MeshResource, global_mesh_resource +from transformer_engine.jax import fp8_autocast + + +def generate_mesh_configs(): + configs = [] + if is_devices_enough(2): + configs.append( + [2, (1, 2), ("dp", "tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp")] + ) + if is_devices_enough(4): + configs.append( + [4, (2, 2), ("fsdp", "tp"), MeshResource(tp_resource="tp", fsdp_resource="fsdp")] + ) + return configs + + +class TestMeshResource(unittest.TestCase): + def test_fp8_autocast_with_mesh_resource(self): + for mesh_config in generate_mesh_configs(): + device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = jax.sharding.Mesh(devices, mesh_axes) + with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + self.assertEqual(mesh_resource, global_mesh_resource()) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 57098b0e2..03c0d1119 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -27,6 +27,7 @@ NORM_INPUT_SHAPES = { "L0": [[64, 64]], + "L1": [[64, 64]], "L2": [[64, 64]], } diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 694610978..0af10d050 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -38,7 +38,11 @@ from transformer_engine.jax.quantize import QuantizerFactory from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability -from transformer_engine.jax.util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type +from transformer_engine.jax.util import ( + is_hip_extension, + get_jnp_float8_e4m3_type, + get_jnp_float8_e5m2_type, +) jnp_float8_e4m3_type = get_jnp_float8_e4m3_type() jnp_float8_e5m2_type = get_jnp_float8_e5m2_type() @@ -67,16 +71,16 @@ INTERMEDIATE = 64 -# Only test with FSDP and TP as DP is not used -def generate_fsdp_and_tp_configs(): +# Only test with FSDP and TPSP as DP is not used +def generate_fsdp_and_tpsp_configs(): configs = [] if is_devices_enough(2): configs.append( - [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] + [2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] ) if is_devices_enough(4): configs.append( - [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] + [4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] ) return configs @@ -178,7 +182,9 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with fp8_autocast( + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() + ): single_jitter = jax.jit( value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), @@ -189,14 +195,14 @@ def _test_layernorm_mlp_grad( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource ): - k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) - k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) + k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) k1_ = jax.device_put(k1, k1_sharding) k2_ = jax.device_put(k2, k2_sharding) if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp")) b1_ = jax.device_put(b1, b1_sharding) else: b1_sharding = b1_ = None @@ -230,12 +236,18 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) # TODO: skip cases with single fwd as nan/inf - if jnp.any(jnp.isnan(single_fwd)) or jnp.any(jnp.isinf(single_fwd)): + if is_hip_extension() and (jnp.any(jnp.isnan(single_fwd)) or + jnp.any(jnp.isinf(single_fwd))): pytest.skip("skip tests with nan/inf single fwd.") fwd_test_type = dtype if fp8_recipe is None else jnp_float8_e4m3_type bwd_test_type = dtype if fp8_recipe is None else jnp_float8_e5m2_type - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + + if fwd_test_type == jnp.float16 and use_bias: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5) + else: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): @@ -256,12 +268,12 @@ def _test_layernorm_mlp_grad( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", 
generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, @@ -285,12 +297,12 @@ def test_layernorm_mlp_grad( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( self, @@ -339,10 +351,9 @@ def _test_layernorm_mlp( with use_jax_gemm(enabled=with_jax_gemm): # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, - transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, use_bias=use_bias, @@ -361,7 +372,6 @@ def _test_layernorm_mlp( ): ln_mlp_sharded = LayerNormMLP( layernorm_type=layernorm_type, - transpose_batch_sequence=False, intermediate_dim=INTERMEDIATE, activations=activation_type, scale_axes=LN_SCALE_AXES, @@ -419,7 +429,7 @@ def _test_layernorm_mlp( assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @@ -440,7 +450,7 @@ def test_layernorm_mlp_layer( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @@ -463,7 +473,7 @@ def test_layernorm_mlp_layer_fp8( ) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @@ -484,7 +494,7 @@ def test_layernorm_mlp_layer_shardy( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + 
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index bf70937ab..d9eaf314a 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -5,7 +5,6 @@ import warnings from functools import partial import pytest -from packaging import version import jax import jax.numpy as jnp @@ -42,11 +41,11 @@ def generate_inputs( if not bad_sharding: x_pspec = PartitionSpec( - mesh_resource.dp_resource, mesh_resource.tp_resource, None, None + mesh_resource.dp_resource, mesh_resource.tpsp_resource, None, None ) else: x_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource + mesh_resource.dp_resource, None, None, mesh_resource.tpsp_resource ) if broadcast_batch_mask: @@ -136,7 +135,7 @@ def impl_test_softmax( ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) + @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]]) @pytest.mark.parametrize( "softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], @@ -169,15 +168,14 @@ def test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy=False, + use_shardy=True, ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED]) @pytest.mark.parametrize("bad_sharding", [False, True]) @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) - @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") - def test_softmax_shardy( + def test_softmax_gspmd( self, device_count, mesh_shape, @@ -198,5 +196,5 @@ def test_softmax_shardy( dtype=DTYPES[0], bad_sharding=bad_sharding, broadcast_batch_mask=broadcast_batch_mask, - use_shardy=True, + use_shardy=False, ) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index e08c3a1b9..4d7718cd0 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
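The softmax test above now builds its PartitionSpec from the tpsp (tensor plus sequence parallel) mesh resource rather than tp. A minimal JAX sketch of what that sharding looks like on a 1x2 device mesh, using plain jax.sharding APIs rather than TE's MeshResource plumbing; it assumes at least two visible devices.

# Minimal sharding sketch; assumes >= 2 devices are available.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()[:2]).reshape(1, 2)
mesh = Mesh(devices, ("dp", "tpsp"))

x = jnp.zeros((32, 12, 128, 128), dtype=jnp.bfloat16)
# Batch over "dp", attention heads over "tpsp", softmax axes replicated.
x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("dp", "tpsp", None, None)))
print(x_sharded.sharding.spec)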
@@ -44,6 +44,7 @@ from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, + get_device_compute_capability, ) from distributed_test_base import assert_equal_collectives @@ -352,6 +353,14 @@ def _check_configs(self): "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) + if ( + get_device_compute_capability(0) == 100 + and self.dropout_prob == 0.1 + and self.attn_bias_type is not AttnBiasType.NO_BIAS + ): + pytest.skip( + "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): @@ -388,8 +397,12 @@ def _check_configs(self): self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() - if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: - pytest.skip("Unsupported inputs combination or device compute capability.") + if is_hip_extension(): + if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: + pytest.skip("Unsupported inputs combination or device compute capability.") + else: + if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: + pytest.skip("Unsupported inputs combination or device compute capability.") if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS @@ -413,7 +426,7 @@ def _setup_inputs(self): self.mesh = Mesh(self.devices, self.mesh_axes) self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1) self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) - self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1) + self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1) # only support new-style RNGs on AMD hardware since they will crash otherwise if is_hip_extension() and not self.use_old_rng: @@ -651,7 +664,7 @@ def generate_random_segment_ids( self.qkvo_psec = PartitionSpec( self.mesh_resource.dp_resource, self.mesh_resource.cp_resource, - self.mesh_resource.tp_resource, + self.mesh_resource.tpsp_resource, None, ) self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec) @@ -679,7 +692,7 @@ def to_dp_shardings(x): if self.bias_shape == BiasShape._1HSS: self.bias_pspec = PartitionSpec( - None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None + None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None ) elif self.bias_shape == BiasShape._B1SS: self.bias_pspec = PartitionSpec( diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index e237318a4..e4511e1fe 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -14,10 +14,11 @@ from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax.quantize import ( - QuantizeConfig, + get_quantize_config, is_fp8_available, ScalingMode, update_collections, + TensorSource, ) from transformer_engine.jax.sharding import MeshResource, global_mesh_resource @@ -49,7 +50,7 @@ def test_update_collections(self): class TestFP8Functions(unittest.TestCase): def _check_default_state(self): - self.assertFalse(QuantizeConfig.is_fp8_enabled()) + self.assertFalse(get_quantize_config().is_fp8_enabled()) def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.margin == test.margin) @@ -58,108 +59,90 @@ def _compare_delay_scaling(self, ref, test): 
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) def _compare_current_scaling(self, test): - self.assertEqual(QuantizeConfig.MARGIN, test.margin) - self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) - self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) + self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + for tensor_source in TensorSource: + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), + ScalingMode.CURRENT_TENSOR_SCALING, + ) def _compare_mxfp8_scaling(self, test): - self.assertEqual(QuantizeConfig.MARGIN, test.margin) - self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) - self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING) + self.assertEqual(get_quantize_config().MARGIN, test.margin) + self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + for tensor_source in TensorSource: + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING + ) @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_delayed_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() - with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()): + with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()): self._check_default_state() self._check_default_state() ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) self._check_default_state() ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) self._check_default_state() - @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) + @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_current_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. 
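A short sketch of the query API these comparisons rely on: scaling modes are now looked up per TensorSource from get_quantize_config() instead of being read off a single SCALING_MODE class attribute. The helper names below are illustrative only.

from transformer_engine.jax.quantize import (
    ScalingMode,
    TensorSource,
    get_quantize_config,
)

def scaling_modes_by_source():
    cfg = get_quantize_config()
    return {src: cfg.get_scaling_mode(src) for src in TensorSource}

def all_sources_use(mode):
    # e.g. all_sources_use(ScalingMode.MXFP8_1D_SCALING)
    return all(m == mode for m in scaling_modes_by_source().values())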
self._check_default_state() - with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()): + with fp8_autocast( + enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource() + ): self._check_default_state() self._check_default_state() - cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3) - with fp8_autocast(enabled=True, fp8_recipe=cs): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) + with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() - cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID) - with fp8_autocast(enabled=True, fp8_recipe=cs): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) + with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) def test_fp8_autocast_mxfp8_block_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() - with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()): + with fp8_autocast( + enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource() + ): self._check_default_state() self._check_default_state() bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) - with fp8_autocast(enabled=True, fp8_recipe=bs): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) - with fp8_autocast(enabled=True, fp8_recipe=bs): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() - - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_fp8_autocast_with_sharding_resource(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. - self._check_default_state() - - ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) - - mesh_s = ( - (MeshResource(None, None)), - (MeshResource("dp", None)), - (MeshResource(None, "tp")), - (MeshResource("dp", "tp")), - ) - # TODO (Ming Huang): Support multi-GPUs testing. 
# pylint: disable=fixme - mesh_shape = (1, 1) - devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape) - with jax.sharding.Mesh(devices, ("dp", "tp")): - for sr in mesh_s: - with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) - self.assertEqual(sr, global_mesh_resource()) - - self._check_default_state() diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index d59e13053..6f672ade7 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -23,11 +23,14 @@ from transformer_engine.common import recipe from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.quantize import ( - QuantizeConfig, + get_quantize_config, ScalingMode, is_fp8_available, update_collections, + TensorSource, + fp8_autocast, ) +from transformer_engine.jax.sharding import MeshResource @pytest.fixture(autouse=True, scope="function") @@ -262,6 +265,16 @@ def enable_fused_attn(): _KEY_OF_RELATIVE_EMBEDDING: False, _KEY_OF_WINDOW_SIZE: (2, 2), }, + # attrs29 + { + _KEY_OF_RELATIVE_EMBEDDING: True, + _KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias", + }, + # attrs30 + { + _KEY_OF_RELATIVE_EMBEDDING: True, + _KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias", + }, ] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] @@ -345,7 +358,7 @@ def test_backward( ref_params, test_params = self._sync_params(ref_params, test_params) - if QuantizeConfig.is_fp8_enabled(): + if get_quantize_config().is_fp8_enabled(): for _ in range(4): _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( inputs, @@ -354,12 +367,15 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: + if ( + get_quantize_config().get_scaling_mode(TensorSource.X) + == ScalingMode.DELAYED_TENSOR_SCALING + ): _, updated_quantize_meta = flax.core.pop( - updated_state[0], QuantizeConfig.COLLECTION_NAME + updated_state[0], get_quantize_config().COLLECTION_NAME ) test_others = update_collections( - {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others + {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others ) del updated_quantize_meta del updated_state @@ -489,29 +505,33 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" - QuantizeConfig.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_forward(data_shape, dtype) + # Ensure FP8 disabled. + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=False, mesh_resource=MeshResource()): + self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" - QuantizeConfig.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_backward(data_shape, dtype) + # Ensure FP8 disabled. 
+ # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=False, mesh_resource=MeshResource()): + self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" - QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) - QuantizeConfig.finalize() + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" - QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) - QuantizeConfig.finalize() + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) class TestEncoderLayer(BaseTester): diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py new file mode 100644 index 000000000..31209d1bc --- /dev/null +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
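Ahead of the body of this new multi-process test, a hedged sketch of its central call pattern: build a current-tensor-scaling FP8 quantizer set covering all groups and hand it to the grouped dense op, optionally with FSDP placement info for the kernel. The wrapper name and the conditional plumbing are illustrative; the keyword names come from the test that follows.

import jax.numpy as jnp
from transformer_engine.jax.dense import grouped_dense as te_grouped_dense
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode

def grouped_dense_fp8(x_2d, w, rows_per_group, kernel_amax=None, kernel_fsdp_info=None):
    # One quantizer set for all groups: FP8 e4m3 forward, e5m2 backward.
    quantizer_set = QuantizerFactory.create_set(
        scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
        fwd_dtype=jnp.float8_e4m3fn,
        bwd_dtype=jnp.float8_e5m2,
        is_2x2x=True,
        n_groups=len(rows_per_group),
    )
    kwargs = {"quantizer_set": quantizer_set}
    # Only the sharded-kernel path of the test supplies these two.
    if kernel_amax is not None:
        kwargs["kernel_amax"] = kernel_amax
    if kernel_fsdp_info is not None:
        kwargs["kernel_fsdp_info"] = kernel_fsdp_info
    return te_grouped_dense(x_2d, w, rows_per_group, **kwargs)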
+ +from functools import partial + +import jax +import jax.numpy as jnp +import jax.experimental.multihost_utils as jem + +from transformer_engine.jax.dense import grouped_dense as te_grouped_dense +from transformer_engine.jax.quantize import ( + QuantizerFactory, + ScalingMode, +) + +from utils import assert_allclose, dtype_tols + + +N_GROUP = 8 +MESH_AXIS_NAME = "fsdp" + + +def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): + assert kernel_fsdp_axis in [1, 2] + x_shape, w_shape = data_shapes + + x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None)) + w_sharding = ( + NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME)) + if kernel_fsdp_axis == 2 + else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None)) + ) + w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None)) + + def init_data(): + x_key = jax.random.PRNGKey(0) + w_key = jax.random.PRNGKey(1) + x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) + w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) + w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim)) + return x, w, w, w_amax + + def test_func(outter_x, outter_w, outter_w_amax): + in_specs = (x_sharding.spec, w_sharding.spec, None) + out_specs = x_sharding.spec + + @partial( + shard_map.shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + def sharded_group_gemm(x, w, w_amax): + group_size = x.shape[0] + x_reshaped = x.reshape(-1, x.shape[-1]) + n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=True, + n_groups=group_size, + ) + + output = te_grouped_dense( + x_reshaped, + w, + n_groups, + kernel_amax=w_amax, + quantizer_set=quantizer_set, + kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), + ) + output = output.reshape(*x.shape[:-1], -1) + return output + + def run(x, w, w_amax): + output = sharded_group_gemm(x, w, w_amax) + return output + + output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_w_amax) + dx, dw, _ = vjp_fn(output) + return output, dx, dw + + def ref_func(outter_x, outter_w): + + in_specs = (x_sharding.spec, w_no_sharding.spec) + out_specs = x_sharding.spec + + @partial( + shard_map.shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + def sharded_group_gemm(x, w): + group_size = x.shape[0] + x_reshaped = x.reshape(-1, x.shape[-1]) + n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=True, + n_groups=group_size, + ) + output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set) + output = output.reshape(*x.shape[:-1], -1) + return output + + def run(x, w): + output = sharded_group_gemm(x, w) + return output + + output, vjp_fn = jax.vjp(run, outter_x, outter_w) + dx, dw = vjp_fn(output) + return output, dx, dw + + init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None)) + x, w, w_global, w_amax = init_func() + + o_sharding = x_sharding + test_func_jitted = jax.jit( + test_func, + in_shardings=(x_sharding, w_sharding, None), + out_shardings=(o_sharding, x_sharding, w_sharding), + ) + ref_func_jitted = jax.jit( + ref_func, + 
in_shardings=(x_sharding, w_no_sharding), + out_shardings=(o_sharding, x_sharding, w_no_sharding), + ) + + out, dx, dw = test_func_jitted(x, w, w_amax) + ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) + + e4m3_tols = dtype_tols(jnp.float8_e4m3fn) + e5m2_tols = dtype_tols(jnp.float8_e5m2) + + out, ref_out = jem.process_allgather((out, ref_out)) + dx, ref_dx = jem.process_allgather((dx, ref_dx)) + dw, ref_dw = jem.process_allgather((dw, ref_dw)) + + jnp.allclose(out, ref_out, **e4m3_tols) + jnp.allclose(dx, ref_dx, **e5m2_tols) + jnp.allclose(dw, ref_dw, **e5m2_tols) + + +if __name__ == "__main__": + from jax.sharding import NamedSharding, PartitionSpec + from jax.experimental import shard_map + import sys + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id + ) + + mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,)) + + with mesh: + data_shapes = [((4, 16, 128, 7168), (7168, 2048))] + for data_shape in data_shapes: + for kernel_fsdp_axis in [1, 2]: + test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis) diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py deleted file mode 100644 index 0d50b7345..000000000 --- a/tests/jax/test_sharding.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import pytest - -from transformer_engine.jax.flax import extend_logical_axis_rules -from transformer_engine.jax.sharding import global_shard_guard, MeshResource - -LOGICAL_RULES = [ - [(("a1", None), ("a2", "ma2")), False], - [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True], - [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False], - [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True], - [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True], -] - -MeshS = [ - MeshResource(), - MeshResource("data", None), - MeshResource(None, "model"), - MeshResource("data", "model"), -] - - -class TestShardingSideAPI: - - @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES) - @pytest.mark.parametrize("sr", MeshS) - def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): - with global_shard_guard(sr): - try: - target_te_rules = extend_logical_axis_rules(tuple()) - extended_rules = extend_logical_axis_rules(base_rules) - assert extended_rules == (*base_rules, *target_te_rules) - assert not need_assert - except AssertionError as ae: - assert need_assert, f"{ae.args}" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index f34fb5448..56d5df8e3 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1607,16 +1607,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): @contextmanager def use_jax_gemm(enabled=False): - orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None) try: if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false" + else: + os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true" yield finally: if enabled: if orig_custom_calls_filter is None: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + os.environ.pop("NVTE_JAX_CUSTOM_CALLS") else: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter + 
os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py similarity index 99% rename from tests/pytorch/fused_attn/run_fused_attn_with_cp.py rename to tests/pytorch/attention/run_attention_with_cp.py index 672950f50..10bb066a4 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -16,7 +16,7 @@ get_cu_seqlens_on_cp_rank, ) import transformer_engine_torch as tex -from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/attention/test_attention.py similarity index 77% rename from tests/pytorch/fused_attn/test_fused_attn.py rename to tests/pytorch/attention/test_attention.py index 04de02761..a5128653e 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/attention/test_attention.py @@ -1,18 +1,20 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import logging import math import os -from torch.utils.cpp_extension import IS_HIP_EXTENSION -from typing import Any, Dict, List, Tuple, Union, Optional -from contextlib import contextmanager +import sys +import pathlib +from typing import Any, Dict, Tuple, Union import pytest import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION + from transformer_engine.common import recipe from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init from transformer_engine.pytorch.attention.dot_product_attention import ( @@ -22,11 +24,8 @@ from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, - get_attention_backend, check_set_window_size, - AttentionParams, ) -from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import RotaryPositionEmbedding import transformer_engine.pytorch.cpp_extensions as ext from transformer_engine.pytorch.cpp_extensions.fused_attn import ( @@ -51,21 +50,21 @@ restore_from_saved, ) +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + reset_rng_states, + ModelConfig, + dtype_tols, + get_available_attention_backends, +) + # Only run FP8 tests on H100 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() -# Initialize RNG state seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - - -def reset_rng_states() -> None: - """Revert back to initial RNG state""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) +# Reset RNG states +reset_rng_states() @pytest.fixture(autouse=True) @@ -74,208 +73,30 @@ def 
reset_global_fp8_state(): fp8.FP8GlobalStateManager.reset() -class EnvVarCleaner: - def __init__(self, envs_): - self.envs = envs_ - self.flags = {} - for env in self.envs: - if env in os.environ: - self.flags[env] = os.environ[env] - def __del__(self): - for env in self.envs: - if env in self.flags: - os.environ[env] = self.flags[env] - else: - os.environ.pop(env, None) - - -@pytest.fixture(autouse=True) -def reset_attn_backend(): - env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", - "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", - "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"]) - yield - - -class ModelConfig: - def __init__( - self, - batch_size: int, - num_heads: int, - num_gqa_groups: int, - head_dim_qk: int, - max_seqlen_q: int, - max_seqlen_kv: int, - dropout_p: float, - attn_mask_type: str, - attn_bias_type: str, - head_dim_v: int = None, - alibi_type: str = "none", - num_layers: int = 1, - bias_shape: str = "1hss", - window_size: Tuple[int, int] = (-1, -1), - total_requests: int = None, - max_ctx_len: int = None, - ): - self.batch_size = batch_size - self.num_heads = num_heads - self.num_gqa_groups = num_gqa_groups - self.head_dim_qk = head_dim_qk - self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v - self.hidden_size = num_heads * head_dim_qk - self.hidden_size_kv = num_gqa_groups * self.head_dim_v - self.max_seqlen_q = max_seqlen_q - self.max_seqlen_kv = max_seqlen_kv - self.dropout_p = dropout_p - self.attn_mask_type = attn_mask_type - self.attn_bias_type = attn_bias_type - self.alibi_type = alibi_type - self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross" - self.num_layers = num_layers - self.bias_shape = bias_shape - self.window_size = window_size - self.total_requests = total_requests - self.max_ctx_len = max_ctx_len - - -@contextmanager -def logging_context(highest_level=logging.WARNING): - previous_level = logging.root.manager.disable - logging.disable(highest_level) - try: +if IS_HIP_EXTENSION: + from utils import EnvVarCleaner + @pytest.fixture(autouse=True) + def reset_attn_backend(): + env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) yield - finally: - logging.disable(previous_level) - - -def _get_attention_backends( - config: ModelConfig, - qkv_dtype: torch.dtype, - qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), - pad_between_seqs: bool = False, - context_parallel: bool = False, - deterministic: bool = False, - fp8: bool = False, - fp8_meta: Optional[Dict[str, Any]] = None, - is_training: bool = True, - inference_params: Optional[InferenceParams] = None, -) -> Tuple[List, List]: - """Check if what attention backends support a model configuration""" - - os.environ["NVTE_FLASH_ATTN"] = "1" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True - - alibi_slopes_shape = None - if config.attn_bias_type == "alibi" and config.alibi_type == "custom": - if config.bias_shape == "1hss": - alibi_slopes_shape = [config.num_heads] - if config.bias_shape == "bhss": - alibi_slopes_shape = [config.batch_size, config.num_heads] - - core_attention_bias_shape = ( - config.bias_shape if config.attn_bias_type == "post_scale_bias" else None - ) - core_attention_bias_requires_grad = False - # d=256 is supported by cuDNN 9.0+ for inference but not training - if ( - 
config.attn_bias_type == "post_scale_bias" - and config.head_dim_qk <= 128 - and config.head_dim_v <= 128 - ): - core_attention_bias_requires_grad = True - - fused_attn_backends = [] - available_backends = None - flash_attention_backend = None - fused_attention_backend = None - - def test(): - attention_params = AttentionParams( - qkv_dtype=qkv_dtype, - qkv_layout=qkv_layout, - batch_size=config.batch_size, - num_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - max_seqlen_q=config.max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - head_dim_qk=config.head_dim_qk, - head_dim_v=config.head_dim_v, - attn_mask_type=config.attn_mask_type, - window_size=window_size, - alibi_slopes_shape=alibi_slopes_shape, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias_shape=core_attention_bias_shape, - core_attention_bias_requires_grad=core_attention_bias_requires_grad, - pad_between_seqs=pad_between_seqs, - attention_dropout=config.dropout_p, - context_parallel=context_parallel, - deterministic=deterministic, - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - inference_params=inference_params, - ) - ( - use_flash_attention, - flash_attention_backend, - use_fused_attention, - fused_attention_backend, - use_unfused_attention, - available_backends, - ) = get_attention_backend(attention_params) - # Set attention.py _attention_backends var using return value - # from get_attention_backend() - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["flash_attention_backend"] = flash_attention_backend - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - return available_backends, flash_attention_backend, fused_attention_backend - - if IS_HIP_EXTENSION: - backends = {"AOTriton": "AOTRITON", "CK": "CK"} - with logging_context(): - for i in backends.keys(): - for k in backends.keys(): - os.environ["NVTE_FUSED_ATTN_"+backends[k]] = "0" - os.environ["NVTE_FUSED_ATTN_"+backends[i]] = "1" - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[i]: - fused_attn_backends.append(fused_attention_backend) - for i in backends.keys(): - del os.environ["NVTE_FUSED_ATTN_"+backends[i]] - available_backends[1] = len(fused_attn_backends) > 0 - else: - backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - with logging_context(): - for i in range(len(backends)): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) - return available_backends, flash_attention_backend, fused_attn_backends model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias - "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 
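The _get_attention_backends helper deleted above now lives in tests/pytorch/utils.py as get_available_attention_backends and returns a three-tuple. A minimal usage sketch (the ModelConfig values are arbitrary; the import assumes the TE repo root is on sys.path, as the updated tests arrange):

import torch
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

config = ModelConfig(2, 2048, 16, 64, attn_mask_type="causal")  # b, sq, h, dqk
available_backends, flash_backend, fused_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="sbhd_sbhd_sbhd",
)
flash_ok, fused_ok, unfused_ok = available_backends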
0.0, "no_mask", "no_bias"), - "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), - "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), - "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_1_0": ModelConfig(8, 128, 16, 64), + "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), + "base_2_0": ModelConfig(2, 2048, 24, 128), + "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), + "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048), + "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048), + "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048), + "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048), + "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), + "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), + "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), + "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), } @@ -292,10 +113,11 @@ def test_gqa_mla_thd(): Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration post-processing for BWD FA with native padding support. """ - config = ModelConfig(8, 16, 4, 128, 128, 128, 0.0, "padding", "no_bias", head_dim_v=64) + # b, sq, h, dqk + config = ModelConfig(8, 128, 16, 128, num_gqa_groups= 4, head_dim_v=64, attn_mask_type="padding") qkv_layout = "thd_thd_thd" dtype = torch.float16 - _, _, fused_attn_backends = _get_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -317,10 +139,11 @@ def test_dot_product_mem_calc(): if not is_bf16_compatible(): pytest.skip("This test requires bf16 support.") dtype = torch.bfloat16 - config = ModelConfig(16, 128, 8, 128, 8192, 8192, 0.0, "causal", "no_bias") + # b, sq, q, dqk + config = ModelConfig(16, 8192, 128, 128, num_gqa_groups=8, attn_mask_type="causal") is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 qkv_layout = "sbhd_sbhd_sbhd" - _, _, fused_attn_backends = _get_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -381,8 +204,8 @@ def test_dot_product_attention( config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) - is_training = config.head_dim_qk <= 192 and config.head_dim_v <= 128 - available_backends, _, fused_attn_backends = _get_attention_backends( + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -393,7 +216,7 @@ def test_dot_product_attention( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -548,39 +371,21 @@ def test_dpa_checkpoint(dtype, model_configs, model): model_configs_mla = { # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend - "mla_1_0": ModelConfig( - 8, 16, 
16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # self , 0 - "mla_1_1": ModelConfig( - 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # cross, 0 - "mla_1_2": ModelConfig( - 4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # cross, 0 - "mla_2_0": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 - ), # self , 1 + "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 + "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 + "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 + "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 "mla_2_1": ModelConfig( - 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 ), # cross, 1 "mla_2_2": ModelConfig( - 1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128 + 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 ), # cross, 1 - "mla_3_0": ModelConfig( - 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 - ), # inference - "mla_3_1": ModelConfig( - 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # inference - "mla_3_2": ModelConfig( - 8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # inference - "mla_4_0": ModelConfig( - 10, 16, 16, 192, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=128 - ), - "mla_4_1": ModelConfig( - 10, 16, 16, 192, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), + "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference + "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference } @@ -595,40 +400,46 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias - "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "mask_5_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), + "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"), + "mask_1_2": 
ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), + "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"), + "mask_2_1": ModelConfig( + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right" + ), + "mask_2_2": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right" ), + "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), + "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"), + "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"), + "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "mask_5_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right" ), "mask_5_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" ), - "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), - "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), - "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"), - "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"), - "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"), - "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"), + "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"), + "mask_7_0": ModelConfig( + 2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right" + ), + "mask_7_1": ModelConfig( + 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right" + ), + "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"), + "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"), + "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"), + "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"), "mask_10_0": ModelConfig( - 2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right" ), "mask_10_1": ModelConfig( - 2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right" ), } @@ -644,44 +455,102 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { # test: b, h, hg, d, sq, skv, p, mask, bias - "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), - "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"), - "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"), - 
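For reference while reading these config rewrites, the old and new ModelConfig spellings side by side: the positional order is (batch_size, max_seqlen_q, num_heads, head_dim_qk) per the earlier "# b, sq, h, dqk" comments, and omitted fields fall back to defaults (max_seqlen_kv equal to max_seqlen_q, no mask, no bias), as the mappings in this hunk imply.

from tests.pytorch.utils import ModelConfig  # shared test helper

# Old positional form: ModelConfig(b, h, hg, d, sq, skv, p, mask, bias), e.g.
#   ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias")
# New form, equivalent to the "bias_1_3" entry below:
cfg = ModelConfig(
    2,      # batch_size
    2048,   # max_seqlen_q
    24,     # num_heads
    128,    # head_dim_qk
    max_seqlen_kv=4096,
    attn_bias_type="post_scale_bias",
)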
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"), - "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped - "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped - "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped - "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), + "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), + "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), + "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), + "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped + "bias_1_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi" + ), # skipped + "bias_2_0": ModelConfig( + 4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" + ), # skipped + "bias_2_1": ModelConfig( + 2, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding", + attn_bias_type="post_scale_bias", + ), # skipped "bias_2_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias" + 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias" ), # skipped "bias_2_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_2_4": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi" + ), # skipped + "bias_2_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi" ), # skipped - "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped - "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped - "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"), - "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), + "bias_3_0": ModelConfig( + 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "bias_3_1": ModelConfig( + 2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "bias_3_2": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), "bias_3_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"), + "bias_3_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi" ), # skipped - "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"), - "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped "bias_4_0": ModelConfig( - 4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias" + 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" ), # skipped "bias_4_1": ModelConfig( - 2, 16, 16, 64, 128, 256, 0.0, 
"padding_causal", "post_scale_bias" + 2, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", ), # skipped "bias_4_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias" + 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" ), # skipped "bias_4_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_4_4": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi" + ), # skipped + "bias_4_5": ModelConfig( + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="alibi", ), # skipped - "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped - "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped } @@ -696,33 +565,29 @@ def test_dpa_bias(dtype, model_configs, model): model_configs_bias_shapes = { # test: b, h, hg, d, sq, skv, p, - "bias_1_0": ModelConfig( + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"), + "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), + "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), + "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"), + "bias_1_4": ModelConfig( 4, - 16, - 16, - 64, - 128, + 2048, + 24, 128, - 0.0, - # mask, bias, bias_shape, - "no_mask", - "post_scale_bias", - bias_shape="11ss", - ), - "bias_1_1": ModelConfig( - 2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss" - ), - "bias_1_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss" - ), - "bias_1_3": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss" - ), - "bias_1_4": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom" + attn_mask_type="causal", + attn_bias_type="alibi", + bias_shape="1hss", + alibi_type="custom", ), "bias_1_5": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom" + 2, + 2048, + 24, + 128, + attn_mask_type="causal", + attn_bias_type="alibi", + bias_shape="bhss", + alibi_type="custom", ), } @@ -738,29 +603,31 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", 
"no_bias"), - "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "swa_6_1": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "swa_1_1": ModelConfig(2, 2048, 16, 64), + "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4), + "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096), + "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), + "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"), + "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), + "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"), + "swa_3_2": ModelConfig( + 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right" ), + "swa_3_3": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right" + ), + "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"), + "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"), + "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "swa_6_2": ModelConfig( - 2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right" ), "swa_6_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" ), } @@ -776,13 +643,31 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type - "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"), - "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"), + "alibi_1_0": ModelConfig( + 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla" + ), + "alibi_1_1": ModelConfig( + 1, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="causal", + attn_bias_type="alibi", + alibi_type="vanilla", + ), "alibi_2_0": ModelConfig( - 2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom" + 2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom" ), "alibi_2_1": ModelConfig( - 1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom" + 1, + 1024, + 24, + 128, + max_seqlen_kv=2048, + attn_mask_type="causal", + attn_bias_type="alibi", + alibi_type="custom", ), } @@ -812,16 +697,38 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): model_configs_layout = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "layout_0_1": ModelConfig(2, 16, 16, 64, 
128, 128, 0.0, "causal", "post_scale_bias"), - "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), - "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), - "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), - "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), - "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"), + "layout_0_0": ModelConfig(2, 128, 16, 64), + "layout_0_1": ModelConfig( + 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"), + "layout_0_3": ModelConfig( + 1, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), + "layout_1_0": ModelConfig(2, 2048, 24, 128), + "layout_1_1": ModelConfig( + 2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "layout_1_3": ModelConfig( + 1, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), + "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048), + "layout_2_1": ModelConfig( + 2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), } @@ -838,55 +745,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "layout_2_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), + "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"), + "layout_1_2": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal" ), + "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "layout_2_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right" ), "layout_2_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" - ), - "layout_3_0": ModelConfig( - 2, 16, 
16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" ), + "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)), "layout_3_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4) ), "layout_3_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) - ), - "layout_4_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4) ), + "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)), "layout_4_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0) ), "layout_4_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0) ), "layout_5_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + 2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0) ), "layout_5_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + 2, + 2048, + 24, + 128, + num_gqa_groups=1, + attn_mask_type="padding_causal_bottom_right", + window_size=(4, 0), ), "layout_5_2": ModelConfig( 2, - 24, + 2048, 24, 128, - 2048, - 4096, - 0.0, - "padding_causal_bottom_right", - "no_bias", + max_seqlen_kv=4096, + attn_mask_type="padding_causal_bottom_right", window_size=(4, 0), ), } @@ -913,15 +819,17 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between if (pad_between_seqs==False and get_cudnn_version() < (9, 3, 0)): pytest.skip("cuDNN 9.3.0+ is required to run pad_between_seqs = False"); - _, _, fused_attn_backends = _get_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - window_size=config.window_size, - pad_between_seqs=pad_between_seqs, - ) - if share_cu_seqlens_ref and FusedAttnBackend["CK"] not in fused_attn_backends: - pytest.skip("This test is only required for the CK fused attention backend.") + if share_cu_seqlens_ref: #ROCm specific config + _, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, + ) + if FusedAttnBackend["CK"] not in fused_attn_backends: + pytest.skip("This test is only required for the CK fused attention backend.") + test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs, share_cu_seqlens_ref ) @@ -935,15 +843,16 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad_between_seqs, share_cu_seqlens_ref): config = model_configs[model] - _, _, fused_attn_backends = _get_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - window_size=config.window_size, - pad_between_seqs=pad_between_seqs, - ) - if share_cu_seqlens_ref and FusedAttnBackend["CK"] not in fused_attn_backends: - pytest.skip("This test is 
only required for the CK fused attention backend.") + if share_cu_seqlens_ref: + _, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, + ) + if FusedAttnBackend["CK"] not in fused_attn_backends: + pytest.skip("This test is only required for the CK fused attention backend.") def find_factors(x): f = [] @@ -1325,16 +1234,22 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_te_layer = { # test: b, h, hg, d, sq, skv, p, mask, bias - "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), - "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), - "te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"), - "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), - "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), + "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), + "te_1_1": ModelConfig( + 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "te_1_2": ModelConfig( + 2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" + ), + "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"), + "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"), + "te_2_1": ModelConfig(2, 2048, 16, 64), + "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"), + "te_2_3": ModelConfig( + 1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" + ), + "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"), + "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"), } @@ -1358,7 +1273,7 @@ def test_transformer_layer( # Test backend availability is_training = True - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1369,7 +1284,7 @@ def test_transformer_layer( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1682,20 +1597,164 @@ def _run_transformer_layer( return out, inp.grad +model_configs_fp8_extra_state = { + "large": ModelConfig(2, 128, 4, 128, num_layers=1), +} + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") +@pytest.mark.parametrize("model", ["large"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sanity_attention_extra_state(model, dtype): + config = 
model_configs_fp8_extra_state[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="sb3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported and not flash_attn_supported: + pytest.skip("No attention backend available.") + + outputs = _run_attention_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_attention_extra_state( + dtype, config, mimic_v1_6=True, checkpoint=True + ) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols.update(dict(rtol=2e-2, atol=2e-3)) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + + +def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + steps = 10 + path = "checkpoint.pt" + fp8_enabled = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_enabled, + fp8_mha=False, + ) + + reset_rng_states() + hidden_states = torch.randn( + (config.max_seqlen_q, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + + def get_model(dtype, config): + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.0, + attention_dropout=0.0, + fuse_qkv_params=True, + params_dtype=dtype, + device="cuda", + ) + return block + + block = get_model(dtype, config) + for i in range(steps // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + if checkpoint: + sd = block.state_dict() + if mimic_v1_6: + sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ + "self_attention.core_attention._extra_state" + ] + del sd["self_attention.core_attention._extra_state"] + torch.save(sd, path) + + param_grads = [] + for p in block.parameters(): + if p.requires_grad: + param_grads.append(p.grad.clone()) + + _cpu_rng_state_new = torch.get_rng_state() + _cuda_rng_state_new = torch.cuda.get_rng_state() + + del block + block = get_model(dtype, config) + block.load_state_dict(torch.load(path, weights_only=False)) + torch.set_rng_state(_cpu_rng_state_new) + torch.cuda.set_rng_state(_cuda_rng_state_new) + + for p in block.parameters(): + if p.requires_grad: + p.grad = param_grads.pop(0) + + assert not param_grads, "Oops!" 
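+    # Resume the remaining training steps on the reloaded model so its outputs
+    # can be compared against the uninterrupted (non-checkpointed) run in the test above.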
+ + for i in range((steps + 1) // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + torch.cuda.synchronize() + + if os.path.exists(path): + os.remove(path) + + outputs = [output, hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + + return outputs + + model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), - "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), - "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"), - "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"), + "fp8_9": ModelConfig(2, 2048, 16, 128), + "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), + "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), + "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1745,18 +1804,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, - ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: + pytest.skip("Less than two backends to compare.") + if not fp8_dpa_bwd: + available_backends, _, 
fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1782,11 +1853,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1960,23 +2027,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): # if get_device_compute_capability() >= (10, 0): # config.dropout_p = 0.1 - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, - ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: - pytest.skip("qkv_layout not applicable for MQA/GQA") - os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_layout, + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") + if not fp8_dpa_bwd: + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: + pytest.skip("qkv_layout not applicable for MQA/GQA") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -2005,11 +2083,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -2183,14 +2257,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_fp8 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_5": ModelConfig(1, 1, 1, 64, 
512, 512, 0.0, "causal", "no_bias"), - "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), - "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_1": ModelConfig(1, 512, 1, 64), + "fp8_2": ModelConfig(4, 512, 16, 64), + "fp8_3": ModelConfig(1, 2048, 1, 128), + "fp8_4": ModelConfig(2, 2048, 24, 128), + "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"), + "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"), + "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"), + "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), } param_types_fp8 = [torch.float16, torch.bfloat16] cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1")) @@ -2220,6 +2294,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model): config = model_configs_fp8[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not (fused_attn_backends and unfused_attn_supported): + pytest.skip("Not enough backends to run this test with.") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py similarity index 67% rename from tests/pytorch/fused_attn/test_fused_attn_with_cp.py rename to tests/pytorch/attention/test_attention_with_cp.py index edf518d6b..ece5a37de 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -1,11 +1,14 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
import os import subprocess +import sys +import pathlib + import pytest import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -14,26 +17,34 @@ get_cudnn_version, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils -from test_fused_attn import ModelConfig + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig, get_available_attention_backends + +# Initialize RNG state +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_1_2": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) - ), # MHA - "cp_1_3": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) - ), # MHA - "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA + "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) - ), # GQA - "cp_2_3": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA + "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA + "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA + "cp_3_2": ModelConfig( + 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 + ), # MLA + "cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128), # MLA } @@ -45,7 +56,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): "--nproc-per-node=" + str(num_gpus_per_node), ] te_path = os.getenv("TE_PATH", "/opt/transformerengine") - script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") + script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") args.append(script_path) for k, v in kwargs.items(): args.append(f"{k}={v}") @@ -79,6 +90,11 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
) + if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently only support KV P2P!") + if IS_HIP_EXTENSION: + if config.head_dim_qk != config.head_dim_v and not FlashAttentionUtils.v3_is_installed: + pytest.skip("MLA FlashAttention requires v3+!") subprocess.run( get_bash_arguments( @@ -94,32 +110,36 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA - "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA - "cp_1_4": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_2": ModelConfig( + 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA - "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA - "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA - "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA + "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_2": ModelConfig( + 2, + 4096, + 12, + 128, + num_gqa_groups=2, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + ), # GQA + "cp_2_3": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" + ), # GQA "cp_2_4": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA - "cp_3_0": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64 - ), # MLA - "cp_3_1": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64 - ), # MLA + "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA "cp_3_2": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64 - ), # MLA - "cp_3_3": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64 + 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA + "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA } @@ -178,6 +198,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + available_backends, _, fused_attn_backends = 
get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + window_size=config.window_size, + context_parallel=True, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") subprocess.run( get_bash_arguments( diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py new file mode 100644 index 000000000..00200c62d --- /dev/null +++ b/tests/pytorch/attention/test_cp_utils.py @@ -0,0 +1,715 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for context parallel utils.""" +import torch +import unittest +from typing import Tuple +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( + get_batch_on_this_cp_rank, + pad_thd_sequences_for_cp, + generate_positional_ids_for_cp, +) + + +class TestSequencePadding(unittest.TestCase): + def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self): + """Test with custom padding values for all tensors.""" + # Setup + + input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3]) + cu_seqlens = torch.tensor([0, 3, 5, 9]) + labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100]) + positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]) + divisibility_factor = 8 + + pid = 777 + label_pad = -200 + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Sequence: [ a a a p p p p p b b pppppp ccccpppp] + print("input_ids_padded: ", input_ids_padded) + print("labels_padded: ", labels_padded) + print("positional_ids_padded: ", positional_ids_padded) + print("cu_seqlens_padded: ", cu_seqlens_padded) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + pid, + pid, + pid, + pid, + pid, + 2, + 2, + pid, + pid, + pid, + pid, + pid, + pid, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + pid, + ] + ) + expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24]) + expected_labels_padded = torch.tensor( + [ + -100, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + 13, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + ] + ) + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + assert torch.equal(input_ids_padded, expected_input_ids) + assert torch.equal(labels_padded, expected_labels_padded) + assert torch.equal(positional_ids_padded, expected_positional_ids) + assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded) + + def test_mixed_sequence_lengths_with_divisibility_factor(self): + """Test with sequences both shorter and longer than divisibility factor.""" + # Setup - divisibility factor 6 + # Seq 1: length 2 (shorter than 6, needs 4 padding) + # Seq 2: length 7 (longer than 6, needs 5 padding to reach 12) + # Seq 3: length 4 (shorter than 6, needs 2 padding) + # Seq 4: length 10 (longer than 6, needs 2 padding to reach 12) + + input_ids = torch.tensor( + [1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] + 
) + labels = torch.tensor( + [ + 10, + 11, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 30, + 31, + 32, + 33, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + ] + ) + positional_ids = torch.tensor( + [0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + cu_seqlens = torch.tensor([0, 2, 9, 13, 23]) + divisibility_factor = 6 + + pid = 999 + label_pad = -300 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: [1,1] + 4 pads = 6 total + # Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total + # Seq 3: [3,3,3,3] + 2 pads = 6 total + # Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total + + expected_input_ids = torch.tensor( + [ + 1, + 1, + pid, + pid, + pid, + pid, # Seq 1: 2 + 4 padding + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + pid, + pid, + pid, + pid, # Seq 2: 7 + 5 padding + 3, + 3, + 3, + 3, + pid, + pid, # Seq 3: 4 + 2 padding + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + pid, + pid, # Seq 4: 10 + 2 padding + ] + ) + + expected_labels = torch.tensor( + [ + 10, + 11, + label_pad, + label_pad, + label_pad, + label_pad, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + 30, + 31, + 32, + 33, + label_pad, + label_pad, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [ + 0, + 1, + 2, + 3, + 4, + 5, # Seq 1 positions continue through padding + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 2 positions continue + 0, + 1, + 2, + 3, + 4, + 5, # Seq 3 positions continue + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 4 positions continue + ] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + def test_sequences_longer_than_divisibility_factor(self): + """Test with all sequences longer than the divisibility factor.""" + # Setup - divisibility factor 4, all sequences longer than 4 + # Seq 1: length 7 (needs 1 padding to reach 8) + # Seq 2: length 11 (needs 1 padding to reach 12) + # Seq 3: length 5 (needs 3 padding to reach 8) + + input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, # 7 tokens + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, # 11 tokens + 3, + 3, + 3, + 3, + 3, # 5 tokens + ] + ) + labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 300, + 301, + 302, + 303, + 304, + ] + ) + positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4] + ) + cu_seqlens = torch.tensor([0, 7, 18, 23]) + divisibility_factor = 4 + + pid = 888 + label_pad = -400 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + 
positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: 7 + 1 pad = 8 (divisible by 4) + # Seq 2: 11 + 1 pad = 12 (divisible by 4) + # Seq 3: 5 + 3 pads = 8 (divisible by 4) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + pid, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + 3, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + ] + ) + + expected_labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + label_pad, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + label_pad, + 300, + 301, + 302, + 303, + 304, + label_pad, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + +class TestContextParallelUtils(unittest.TestCase): + """Test utilities for context parallel functionality.""" + + def setUp(self): + """Set up mock distributed environment.""" + # Mock torch.distributed functions + self.original_get_world_size = torch.distributed.get_world_size + self.original_get_rank = torch.distributed.get_rank + + def tearDown(self): + """Restore original torch.distributed functions.""" + torch.distributed.get_world_size = self.original_get_world_size + torch.distributed.get_rank = self.original_get_rank + + def _mock_distributed_env(self, cp_size, cp_rank): + """Mock the distributed environment for testing.""" + + def mock_get_world_size(group=None): + return cp_size + + def mock_get_rank(group=None): + return cp_rank + + torch.distributed.get_world_size = mock_get_world_size + torch.distributed.get_rank = mock_get_rank + + def test_cp_rank_slicing_simple_case(self): + """Test CP rank slicing with a simple 2-rank, single sequence case.""" + # Setup: Single sequence of length 8, CP size = 2 + # Each sequence gets divided into 2*cp_size = 4 slices of size 2 each + # Rank 0 gets slices [0,1] and [6,7] (first and last) + # Rank 1 gets slices [2,3] and [4,5] (second and second-to-last) + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = 
torch.tensor([[3, 4, 5, 6]]) + expected_labels_r1 = torch.tensor([[30, 40, 50, 60]]) + expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5]) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_cp_rank_slicing_multiple_sequences(self): + """Test CP rank slicing with multiple sequences.""" + # Setup: Two sequences of length 8 each, CP size = 2 + # Total sequence length = 16, cu_seqlens = [0, 8, 16] + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]]) + labels = torch.tensor( + [[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]] + ) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8, 16]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # For each sequence, rank 0 gets first and last slices + # Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8] + # Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_with_cp_size_1(self): + """Test that CP size = 1 returns original tensors unchanged.""" + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=1, cp_rank=0) + input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # With CP size = 1, should return original tensors + self.assertTrue(torch.equal(input_ids_result, input_ids)) + self.assertTrue(torch.equal(labels_result, labels)) + self.assertTrue(torch.equal(pos_ids_result, position_ids)) + + def test_cp_rank_slicing_sequence_dim_detection(self): + """Test that the function correctly detects sequence dimension.""" + # Test with sequence dimension = 0 (sequence_length, batch_size) + input_ids = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) # (8, 2) + labels = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) + position_ids = torch.tensor( + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]] + ) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Should get indices [0,1] and [6,7] along dimension 0 + expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + 
self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_mixed_dimensions(self): + """Test CP rank slicing where input_ids/labels are 1D but position_ids has batch dimension.""" + # Setup: Single sequence of length 8, CP size = 2 + # This tests the opposite case from the simple test: + # - input_ids and labels: 1D (no batch dimension) + # - position_ids: 2D (has batch dimension) + + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D + labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D + position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result + expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result + expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result + expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result + expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_integration_with_padding_and_cp_slicing(self): + """Integration test: pad sequences then slice for CP ranks.""" + # Start with unpadded sequences + input_ids = torch.tensor([1, 1, 2, 2, 2]) # Two sequences: [1,1] and [2,2,2] + labels = torch.tensor([10, 11, 20, 21, 22]) + positional_ids = torch.tensor([0, 1, 0, 1, 2]) + cu_seqlens = torch.tensor([0, 2, 5]) + divisibility_factor = 4 # Will pad to lengths 4 and 4 + + # First, pad sequences + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=0, + padding_label_id=-100, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8] + expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0]) + self.assertTrue(torch.equal(input_ids_padded, expected_padded)) + + # Now test CP slicing with cp_size=2 + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens_padded, + input_ids_padded.unsqueeze(0), + labels_padded.unsqueeze(0), + positional_ids_padded, + ) + + # Each sequence of length 4 gets divided into 4 slices of size 1 + # Rank 0 gets slices [0] and [3] from each sequence + # Seq 1: indices [0] and [3] -> values [1] and [0] + # Seq 2: indices [4] and [7] -> values [2] and [0] + expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]]) + + 
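+        # Only the input_ids slice is asserted here; labels and positional ids come from
+        # the same slicing and are covered by the dedicated tests above.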
self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py similarity index 95% rename from tests/pytorch/fused_attn/test_kv_cache.py rename to tests/pytorch/attention/test_kv_cache.py index 967309459..af71866f3 100644 --- a/tests/pytorch/fused_attn/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -5,19 +7,16 @@ from collections import OrderedDict from typing import List import os +import sys +import pathlib import logging import math import pytest import torch -from test_fused_attn import ( - ModelConfig, - reset_rng_states, - _get_attention_backends, -) - from torch.distributions import Exponential +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch import make_graphed_callables from transformer_engine.common import recipe from transformer_engine.pytorch import fp8_autocast, fp8_model_init @@ -34,26 +33,25 @@ is_bf16_compatible, ) -# Initialize RNG state -seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + ModelConfig, + reset_rng_states, + get_available_attention_backends, +) +# Reset RNG states +reset_rng_states() param_types = [torch.float16] if is_bf16_compatible(): param_types.append(torch.bfloat16) model_configs_infer = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "infer_0": ModelConfig( - 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 - ), - "infer_1": ModelConfig( - 2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 - ), + # test: b, sq, hq, dqk, + "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16), + "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16), } qkv_formats = ["bshd", "sbhd", "thd"] @@ -388,6 +386,12 @@ def get_tols(config, module, backend, dtype): torch.half: (1e-2, 1e-2), torch.bfloat16: (8e-2, 7e-2), } + # With FA on ROCm it may not fit default tolerance + if IS_HIP_EXTENSION and backend == "FlashAttention": + tols = { + torch.half: (1e-2, 1e-2), + torch.bfloat16: (1e-1, 1e-1), + } if module == "DotProductAttention": tols = { torch.half: (1e-3, 1e-3), @@ -406,6 +410,11 @@ def get_tols(config, module, backend, dtype): @pytest.mark.parametrize("is_cuda_graph", [False, True]) @pytest.mark.parametrize("is_fp8", [False, True]) def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8): + if IS_HIP_EXTENSION: + if is_paged and backend == "FusedAttention": + pytest.skip("Paged KV cache is not supported for FusedAttention on ROCm") + if qkv_format == "thd" and backend == "FusedAttention": + pytest.skip("THD KV cache is not supported for FusedAttention on ROCm") reset_rng_states() logger = logging.getLogger("test_kv_cache") fp8_recipe = recipe.DelayedScaling( @@ -470,7 +479,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g qkv_layout = qkv_format + "_" + 
"_".join([inference_params_qkv_format] * 2) if is_paged: qkv_layout = "paged_kv_" + qkv_layout - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index b12f8c3d3..716c16056 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -364,6 +364,40 @@ def get_stat(tensor, stat): set_weight_tensor_tp_group_reduce(True) # reset +@run_debug_test +def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs): + from test_log import LOG_QUANTIZED_CONFIG + + kwargs["config_file"].write(LOG_QUANTIZED_CONFIG) + kwargs["config_file"].flush() + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + set_weight_tensor_tp_group_reduce(gather_weight) + if WORLD_SIZE % 2 != 0: + return # skip + TP_SIZE = WORLD_SIZE // 2 + DP_SIZE = 2 + TP_RANK = WORLD_RANK % TP_SIZE + DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE + + debug_api.set_tensor_reduction_group(NCCL_WORLD) + + x, weight = _get_tensors( + parallel_mode, + weight_seed=TP_RANK * 1234, + data_seed=DP_RANK * 1234, + tp_size=TP_SIZE, + tp_rank=TP_RANK, + ) + + tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)] + tp_group = dist.new_group(ranks=tp_group_ranks) + + model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group) + _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group) + + set_weight_tensor_tp_group_reduce(True) # reset + + @run_debug_test def test_log_expert_parallel(**kwargs): """ diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index f9cd234ba..d28db1647 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs): # FP8 enabled - true by the default assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] - # modify_tensor_enabled - False by default + # modify_tensor_enabled - (False, None) by default assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] - # inspect_tensor_enabled - False by default + # inspect_tensor_enabled - (False, None) by default assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.1.attn.qkv", tensor_name="activation", iteration=0 - ) - - # inspect_tensor_postquantize - False by default - assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs): assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="dgrad", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="wgrad", iteration=0 - ) + )[0] # caching assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="dgrad", iteration=0 
- ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="wgrad", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs): assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.mlp.fc1", gemm="fprop", iteration=0 - ) + )[0] assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.mlp.fc1", gemm="wgrad", iteration=0 - ) + )[0] assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.mlp.fc1", gemm="dgrad", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="wgrad", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="dgrad", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): # check modify_tensor_enabled assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0 - ) + )[0] assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0 - ) + )[0] # check modify_tensor @@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): gemm="wgrad", tensor_name="gradient", iteration=0, - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc4", gemm="fprop", tensor_name="activation", iteration=0, - ) + )[0] finally: debug_api.end_debug() @@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs): # modify_tensor_enabled assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 - ) + )[0] # modify_tensor debug_api.transformer_engine.modify_tensor( @@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs): assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.fc2", gemm="wgrad", iteration=0 - ) + )[0] # caching assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.fc2", gemm="wgrad", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs): ) tensor = torch.randn((100, 100, 5)).cuda() - tensor_fp8 = Float8Tensor( - data=tensor.to(torch.uint8).cuda(), - fp8_scale_inv=torch.full([1], 1.0).cuda(), + quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda(), + amax=torch.full([1], 1.0).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, - shape=tensor.shape, - dtype=torch.float32, ) + tensor_fp8 = quantizer(tensor) def log(): from transformer_engine.debug.features.utils.stats_buffer import 
STATS_BUFFERS @@ -260,75 +254,90 @@ def assert_empty(): tensor_name="activation", iteration=200, tp_group=None, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max() assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.1.mlp.fc1", tensor_name="activation", iteration=201 - ) + )[0] assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.2.mlp.fc1", tensor_name="activation", iteration=200 - ) - assert not debug_api.transformer_engine.inspect_tensor_enabled( - "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 + )[0] + + expected_underflows = ( + ((tensor_fp8.dequantize() == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) ) - expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5) - expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5) + assert debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 + )[0] # TE FP8 tensor stats -- - assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 - ) - debug_api.transformer_engine.inspect_tensor_postquantize( + assert debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 + )[0] + debug_api.transformer_engine.inspect_tensor( "decoder.1.mlp.fc1", - tensor=tensor_fp8, tensor_name="gradient", iteration=200, - rowwise=True, tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() torch.testing.assert_close( stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows ) - assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201 - ) - assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 - ) + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="activation", iteration=201 + )[0] + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.2.mlp.fc1", tensor_name="gradient", iteration=200 + )[0] # Second config in same yaml - tensor = torch.rand((100, 100, 5)) + tensor = torch.rand((100, 100, 5)).cuda() debug_api.transformer_engine.inspect_tensor( "decoder.6.mlp.fc1", - tensor=tensor, tensor_name="activation", iteration=200, tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]) - assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean() + torch.testing.assert_close( + stats[("decoder.6.mlp.fc1", "activation", "mean", 200)], tensor.mean() + ) debug_api.transformer_engine.inspect_tensor( "decoder.7.mlp.fc1", - tensor=tensor, tensor_name="weight", iteration=200, tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in ["mean", "std", 
"l1_norm", "min", "max"]) - assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max() + torch.testing.assert_close(stats[("decoder.7.mlp.fc1", "weight", "max", 200)], tensor.max()) assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.7.mlp.fc1", tensor_name="weight", iteration=201 - ) + )[0] assert_empty() finally: @@ -343,21 +352,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs): default_logging_enabled=False, ) - def feed(tensor, tensor_fp8): + def feed(tensor, tensor_fp8, quantizer): debug_api.transformer_engine.inspect_tensor( "decoder.5.mlp.fc1", tensor=tensor, tensor_name="activation", iteration=1, tp_group=None, - ) - debug_api.transformer_engine.inspect_tensor_postquantize( - "decoder.5.mlp.fc1", - tensor=tensor_fp8, - tensor_name="activation", - iteration=1, - rowwise=True, - tp_group=None, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) def log_stats(): @@ -365,26 +369,26 @@ def log_stats(): return STATS_BUFFERS.log_stats() + quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + def fp8_tensor(t): - return Float8Tensor( - data=t.to(torch.uint8).cuda(), - fp8_scale_inv=torch.ones([1]).cuda(), - fp8_dtype=tex.DType.kFloat8E4M3, - shape=t.shape, - dtype=torch.float32, - ) + return quantizer(t.cuda()) shape = [1024, 1024] - tensors = [torch.randn(shape) for _ in range(2)] + tensors = [torch.randn(shape).cuda() for _ in range(2)] tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] - feed(tensors[0], tensors_fp8[0]) - feed(tensors[1], tensors_fp8[1]) + feed(tensors[0], tensors_fp8[0], quantizer) + feed(tensors[1], tensors_fp8[1], quantizer) stats1 = log_stats() tensor2 = torch.cat((tensors[0], tensors[1])).cuda() fp8tensor2 = fp8_tensor(tensor2) - feed(tensor2, fp8tensor2) + feed(tensor2, fp8tensor2, quantizer) stats2 = log_stats() assert len(stats1.keys()) > 0 diff --git a/tests/pytorch/debug/test_configs/log_config.yaml b/tests/pytorch/debug/test_configs/log_config.yaml new file mode 100644 index 000000000..3e94006d9 --- /dev/null +++ b/tests/pytorch/debug/test_configs/log_config.yaml @@ -0,0 +1,19 @@ +test: + enabled: True + layers: + layer_name_regex_pattern: .* + transformer_engine: + LogTensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + start_step: 1 + freq: 3 + LogFp8TensorStats: + enabled: True + tensors: activation + stats: [underflows%] + start_step: 1 + freq: 5 + \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/perf_config.yaml b/tests/pytorch/debug/test_configs/perf_config.yaml new file mode 100644 index 000000000..4ef5e51cd --- /dev/null +++ b/tests/pytorch/debug/test_configs/perf_config.yaml @@ -0,0 +1,13 @@ +test: + enabled: True + layers: + layer_name_regex_pattern: .*1 + transformer_engine: + LogTensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + start_step: 0 + freq: 100000 + \ No newline at end of file diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py new file mode 100644 index 000000000..dcc9861c8 --- /dev/null +++ b/tests/pytorch/debug/test_log.py @@ -0,0 +1,254 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import nvdlfw_inspect.api as debug_api +import transformer_engine.debug +import transformer_engine.pytorch as te +import torch +import tempfile +from transformer_engine.common import recipe +from transformer_engine.pytorch.fp8 import RecipeState +import pytest +import contextlib +import os +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.debug.pytorch.debug_state import TEDebugState + + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) + +LOG_QUANTIZED_CONFIG_BASE = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogFp8TensorStats: + enabled: True + stats: [ + {stats} + ] + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" +recipes = [ + "fp8_delayed_scaling", + "fp8_current_scaling", + "fp8_block_scaling", + "mxfp8", +] + +bare_stats = [ + "underflows%", + "scale_inv_min", + "scale_inv_max", + "mse", +] + +all_stats = [] + +for r in recipes: + for stat in bare_stats: + for columnwise_postfix in ["", "_columnwise"]: + if ( + r in ["fp8_current_scaling", "fp8_block_scaling"] + and torch.cuda.get_device_capability()[0] < 9 + ): + # hopper is needed for current-scaling, block-scaling + continue + if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10: + # blackwell is needed for mxfp8 + continue + if ( + r in ["fp8_delayed_scaling", "fp8_current_scaling"] + and columnwise_postfix == "_columnwise" + ): + # columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling + continue + + all_stats.append(f"{r}_{stat}{columnwise_postfix}") + +all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows% + + +@contextlib.contextmanager +def debug_session(config_str: str, feature_dirs): + """ + Helper context manager that + 1. writes the YAML `config_str` to a temporary file, + 2. starts a debug session, and + 3. yields the directory that contains the statistics log. + + The session is closed automatically – even on exceptions – so every test + stays concise and leak-free. 
+ """ + with tempfile.NamedTemporaryFile( + mode="w", delete=False + ) as cfg_file, tempfile.TemporaryDirectory() as log_dir: + cfg_file.write(config_str) + cfg_file.flush() + + debug_api.initialize( + config_file=cfg_file.name, + feature_dirs=feature_dirs, + log_dir=log_dir, + ) + try: + yield log_dir + finally: + debug_api.end_debug() + + +def read_log(log_dir: str) -> str: + """Return the content of the statistics log produced by `debug_session`.""" + stat_path = os.path.join( + log_dir, + "nvdlfw_inspect_statistics_logs", + "nvdlfw_inspect_globalrank-0.log", + ) + with open(stat_path, "r") as f: + return f.read() + + +def test_sanity(feature_dirs): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats)) + with debug_session(log_all_stats_config, feature_dirs) as log_dir: + model = te.Linear(128, 128, params_dtype=torch.bfloat16) + inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda() + + for _ in range(10): + with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): + output = model(inp) + loss = output.sum() + loss.backward() + debug_api.step() + + output = read_log(log_dir) + + assert output, "Output is empty" + for stat in all_stats: + assert stat in output, f"Stat {stat} not found in output" + + +fp8_recipes = [ + recipe.MXFP8BlockScaling(), + recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), +] + + +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +def test_numerics(fp8_recipe, feature_dirs): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling(): + pytest.skip(reason_for_no_mxfp8) + if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling(): + pytest.skip(reason_for_no_fp8_block_scaling) + + log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats)) + + with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: + recipe_state = RecipeState.create( + fp8_recipe, + mode="forward", + num_quantizers=3, + ) + + tensor = torch.randn(1024, 1024).cuda() + tensor[0, 100:200] = -0.0 + quantizer = recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="layer_name", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + dequantized_tensor = quantized_tensor.dequantize() + output = read_log(log_dir) + + for line in output.splitlines(): + if "underflows%" in line: + underflows = float(line.split("value=")[1]) + expected = ( + ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100 + ) + assert underflows == pytest.approx(expected.cpu(), abs=1e-4) + if "mse" in line: + mse = float(line.split("value=")[1]) + expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean") + assert mse == pytest.approx(expected.cpu(), abs=1e-4) + if "overflows%" in line: + overflows = float(line.split("value=")[1]) + expected = ( + (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100 + ) + assert overflows == pytest.approx(expected.cpu(), abs=1e-4) + + +@pytest.mark.parametrize("layer", ["linear", "transformer"]) +def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): + if not fp8_available: + 
pytest.skip(reason_for_no_fp8)
+
+    # If a layer does not invoke any feature in the current iteration,
+    # then it switches into non-debug mode.
+    # This test checks whether this works correctly:
+    # non-quantized statistics should be logged every 3 iterations,
+    # and quantized statistics should be logged every 5 iterations.
+    with tempfile.TemporaryDirectory() as temp_dir:
+        debug_api.initialize(
+            config_file=configs_dir + "/log_config.yaml",
+            feature_dirs=feature_dirs,
+            log_dir=temp_dir,
+        )
+
+        if layer == "linear":
+            model = te.Linear(128, 128, name="linear1")
+        elif layer == "transformer":
+            model = te.TransformerLayer(128, 128, 4, name="transformer1")
+        else:
+            raise ValueError(f"Invalid layer: {layer}")
+
+        for i in range(20):
+            x = torch.randn(4, 128, 128).cuda()
+            with te.fp8_autocast(enabled=True):
+                y = model(x)
+            y.sum().backward()
+            debug_api.step()
+
+        with open(
+            os.path.join(
+                temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
+            ),
+            "r",
+        ) as f:
+            file_content = f.read()
+        for i in range(1, 20):
+            if i % 3 == 0 or i % 5 == 0:
+                assert f"iteration={i:06d}" in file_content
+            else:
+                assert f"iteration={i:06d}" not in file_content
+
+        debug_api.end_debug()
+        TEDebugState._reset()
diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py
new file mode 100644
index 000000000..2d4b62b23
--- /dev/null
+++ b/tests/pytorch/debug/test_perf.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+
+import pytest
+import torch
+import transformer_engine.pytorch as te
+import time
+
+import nvdlfw_inspect.api as debug_api
+
+from transformer_engine.debug.pytorch.debug_state import TEDebugState
+
+
+def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
+    debug_api.end_debug()
+    TEDebugState._reset()
+    if debug_tools_initialized:
+        # This config logs stats starting from step 0, every N iterations for a huge N >> NUM_ITERS.
+        # So after 1 warm-up iteration, these layers should work in non-debug mode.
+        debug_api.initialize(
+            config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
+        )
+
+    try:
+        if layer == "linear":
+            model = torch.nn.Sequential(
+                te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
+            ).cuda()
+            NUM_ITERS = 18000
+        elif layer == "transformer":
+            model = torch.nn.Sequential(
+                te.TransformerLayer(1, 1, 1, name="transformer1"),
+                te.TransformerLayer(1, 1, 1, name="transformer2"),
+            ).cuda()
+            NUM_ITERS = 2000
+
+        x = torch.randn(1, 1, 1).cuda()
+
+        y = model(x)
+        y.sum().backward()
+        debug_api.step()
+        torch.cuda.synchronize()
+
+        time_start = time.time()
+        for i in range(NUM_ITERS):
+            y = model(x)
+            y.sum().backward()
+            if debug_tools_initialized:
+                debug_api.step()
+        torch.cuda.synchronize()
+        time_end = time.time()
+
+    finally:
+        if debug_tools_initialized:
+            debug_api.end_debug()
+
+    return time_end - time_start
+
+
+@pytest.mark.parametrize("layer", ["linear", "transformer"])
+def test_cpu_overhead(layer, configs_dir, feature_dirs):
+    # Runs one layer many times on a very small tensor
+    # - GPU time should be negligible, so total time should be dominated by CPU time.
+    # If a layer does not invoke any feature in the current iteration,
+    # then it switches into non-debug mode and should not add any non-negligible CPU overhead
+    # compared to a layer without debug tools initialized.
+ + with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs) + without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs) + + print(f"with_debug_tools: {with_debug_tools} s") + print(f"without_debug_tools: {without_debug_tools} s") + + assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 8638c1bce..2a6e55b2c 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -12,6 +12,8 @@ import warnings import pprint import yaml +from contextlib import nullcontext +from functools import partial import torch import torch.distributed as dist @@ -35,9 +37,10 @@ def __init__(self, module, num_layers, *args, **kwargs): self.num_layers = num_layers self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)]) - def forward(self, x): - for layer in self.layers: - x = layer(x) + def forward(self, x, layer_contexts): + for layer, context in zip(self.layers, layer_contexts): + with context(): + x = layer(x) return x @@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None): default=False, help="Print out additional debug information.", ) + parser.add_argument( + "--first-last-layers-bf16", + action="store_true", + default=False, + help="Use bf16 for first and last N layers.", + ) + parser.add_argument( + "--num-layers-at-start-in-bf16", + type=int, + default=0, + help="Number of layers at the start to run in bf16.", + ) + parser.add_argument( + "--num-layers-at-end-in-bf16", + type=int, + default=0, + help="Number of layers at the end to run in bf16.", + ) args = parser.parse_args(argv, namespace) if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!") args.use_cuda_graphs = False + if not args.first_last_layers_bf16 and ( + args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0 + ): + warnings.warn( + "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when" + " first-last-layers-bf16 is enabled!" + ) + args.num_layers_at_start_in_bf16 = 0 + args.num_layers_at_end_in_bf16 = 0 + + if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers: + raise ValueError( + "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to" + " num-layers!" 
+ ) + return args @@ -381,10 +418,21 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): "qkv_dgrad": {"method": "ring_exchange"}, "fc1_dgrad": {"method": "ring_exchange"}, } + + quantization_modes = [ + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) + ] + if opts.first_last_layers_bf16 and opts.fp8: + quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE) + te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], opts.tp, - use_fp8=opts.fp8, + quantization_modes=quantization_modes, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, @@ -423,6 +471,16 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): elif opts.quantization == "mxfp8": fp8_recipe = MXFP8BlockScaling() + layer_contexts = [ + ( + partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world) + if opts.num_layers_at_start_in_bf16 <= i + and i < (opts.num_layers - opts.num_layers_at_end_in_bf16) + else nullcontext + ) + for i in range(opts.num_layers) + ] + # Prepare random input tensors test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) test_x.retain_grad() @@ -435,14 +493,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): # Execute fwd/bwd and collect tensors to test def run_fwd_bwd(model, x): with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): - y = model(x) - if isinstance(y, tuple): - out, *_ = y - else: - out = y - loss = out.sum() - loss.backward() + y = model(x, layer_contexts) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + loss = out.sum() + loss.backward() return out torch_rng_state = torch.get_rng_state() @@ -519,6 +576,7 @@ def run_fwd_bwd(model, x): if opts.use_cuda_graphs: del test_graph + torch.cuda.synchronize() te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 37f0e8669..d6ddfe27c 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -506,7 +506,13 @@ def main() -> None: model_config.num_heads * model_config.head_dim, ], torch.distributed.get_world_size(group), - use_fp8=model_config.quantization is not None, + quantization_modes=[ + ( + te.module.base.UserBufferQuantizationMode.FP8 + if model_config.quantization is not None + else te.module.base.UserBufferQuantizationMode.NONE + ) + ], dtype=model_config.dtype, bootstrap_backend=bootstrap_backend, ub_cfgs=userbuffer_configs, diff --git a/tests/pytorch/distributed/test_sanity.py b/tests/pytorch/distributed/test_sanity.py new file mode 100644 index 000000000..39494a92b --- /dev/null +++ b/tests/pytorch/distributed/test_sanity.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pathlib +import sys +import pytest +import torch +import transformer_engine +from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention +from transformer_engine.pytorch import TransformerLayer, Linear + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig + +model_configs = { + "small": ModelConfig(2, 10, 2, 16), +} + + +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"]) +def test_current_device(model, module): + """Test cases where current device is different from tensor device""" + + num_devices = torch.cuda.device_count() + assert num_devices > 1, "This test requires more than one GPU!" + tensor_device = num_devices - 1 + dtype = torch.bfloat16 + config = model_configs[model] + + args = [] + kwargs = {} + bwd_args = [] + if module == "TransformerLayer": + model = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + params_dtype=dtype, + attn_input_format="thd", + self_attn_mask_type="padding", + device=f"cuda:{tensor_device}", + ) + num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() + args = [ + torch.randn( + (num_tokens, config.hidden_size), + dtype=dtype, + device=f"cuda:{tensor_device}", + requires_grad=True, + ) + ] + cu_seqlens_q, cu_seqlens_kv = [ + torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2) + ] + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_kv"] = cu_seqlens_kv + kwargs["max_seqlen_q"] = config.max_seqlen_q + kwargs["max_seqlen_kv"] = config.max_seqlen_kv + if module == "DotProductAttention": + model = DotProductAttention( + config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding" + ) + num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() + args = [ + torch.randn( + num_tokens, + config.num_heads, + config.head_dim_qk, + dtype=dtype, + device=tensor_device, + requires_grad=True, + ) + for _ in range(3) + ] + cu_seqlens_q, cu_seqlens_kv = [ + torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2) + ] + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_kv"] = cu_seqlens_kv + kwargs["max_seqlen_q"] = config.max_seqlen_q + kwargs["max_seqlen_kv"] = config.max_seqlen_kv + bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)] + elif module == "Linear": + model = Linear( + config.hidden_size, + 4 * config.hidden_size, + params_dtype=dtype, + device=f"cuda:{tensor_device}", + ) + args = [ + torch.randn( + (config.max_seqlen_q, config.batch_size, config.hidden_size), + dtype=dtype, + device=f"cuda:{tensor_device}", + requires_grad=True, + ) + ] + + current_device_before = torch.cuda.current_device() + out = model(*args, **kwargs) + if module == "DotProductAttention": + out.backward(*bwd_args) + else: + loss = out.sum() + loss.backward() + current_device_after = torch.cuda.current_device() + tensor_device_out = out.get_device() + tensor_device_grad = args[0].grad.get_device() + + assert ( + current_device_after == current_device_before + ), "The current device should not have changed!" + assert ( + tensor_device_out == tensor_device + ), "The output tensor should be the same as the input tensors!" + assert ( + tensor_device_grad == tensor_device + ), "The gradient tensor should be the same as the input tensors!" 
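A minimal sketch of the device-preservation contract that the new test_sanity.py above exercises, assuming at least two visible GPUs; the layer sizes, dtype, and tensor shapes here are illustrative and not taken from the test:

import torch
import transformer_engine.pytorch as te

torch.cuda.set_device(0)  # caller's current device
layer = te.Linear(16, 16, params_dtype=torch.bfloat16, device="cuda:1")
x = torch.randn(4, 16, dtype=torch.bfloat16, device="cuda:1", requires_grad=True)

y = layer(x)
y.sum().backward()

assert torch.cuda.current_device() == 0  # current device must be unchanged
assert y.device == x.device              # output stays on the input's device
assert x.grad.device == x.device         # gradient stays on the input's device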
diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 59383f21b..6035a6528 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,30 +2,36 @@ # # See LICENSE for license information. +import contextlib +import gc import os -from contextlib import nullcontext +from typing import Iterable, Optional + import pytest import torch import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported +from utils import ModelConfig, get_available_attention_backends -# Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +# Check supported quantization schemes +fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() -fp8_recipes = [ - None, # non-fp8 - # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet - recipe.Float8CurrentScaling(), - recipe.DelayedScaling(), -] +quantization_recipes: Optional[recipe.Recipe] = [None] +if fp8_available: + quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) -SIZE = 512 -NUM_HEADS = 8 -NUM_LAYERS = 5 -EPSILON = 0.1 +model_config = { + "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), +} +SIZE = model_config["small"].hidden_size +NUM_HEADS = model_config["small"].num_heads +NUM_LAYERS = model_config["small"].num_layers +EPSILON = model_config["small"].eps # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. @@ -46,104 +52,162 @@ "transformer_layer": lambda: te.TransformerLayer( SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 ), + "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + "layernorm_mlp_ops": lambda: te.ops.Sequential( + te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + ), } -def _get_input(): - return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() +def _make_input() -> torch.Tensor: + """Generate random input tensor.""" + return torch.randn( + (128, SIZE, SIZE), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) -def _get_fp8_weight_cache_size(models, fp8_recipe): - """ - Calculate the total FP8 weight cache size (in MB) for a list of models. 
- """ - if fp8_recipe is None: +def _warmup_model( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> None: + """Perform forward and backward pass""" + tensor = _make_input() + for module in modules: + with te.fp8_autocast( + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + ): + tensor = module(tensor) + tensor.sum().backward() + + +def _estimate_cached_weight_size( + model_name: str, + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> float: + """Calculate the memory (in MiB) needed for weight caching.""" + + # The weight params are cached directly for unquantized compute + if quantization_recipe is None: return 0 - params_bytes = 0 - for model in models: - for name, param in model.named_parameters(): - if "weight" in name: - params_bytes += param.numel() + # Count number of weight param elements + param_elements = 0 + for module in modules: + for param in module.parameters(): + if param.dim() == 2: + param_elements += param.numel() + + # FP8 tensor-scaling caches one byte per element + if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): + if not is_non_tn_fp8_gemm_supported() and model_name not in ( + "linear_op", + "layernorm_mlp_ops", + ): + # Modules do not deallocate FP8 transpose for weights + return 2 * param_elements / 1024**2 + return param_elements / 1024**2 + + # MXFP8 caches one data byte per element and one scale byte per 32 + # elements + if quantization_recipe.mxfp8(): + if model_name not in ("linear_op", "layernorm_mlp_ops"): + # Modules do not deallocate column-wise MXFP8 data for weights + return 2 * param_elements * (1 + 1 / 32) / 1024**2 + return param_elements * (1 + 1 / 32) / 1024**2 + + raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + + +def _measure_cached_memory( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], + cpu_offload: bool, +) -> float: + """Measure the growth in allocated GPU memory in MiB after a model forward pass. + + Memory measurement excludes the input and output tensors. 
- # One byte for columnwise and one byte for rowwise, - # hence multiply by 2 and convert to MB - # there is 1 byte of scale per 32 elements in mxFP8 - factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1 - return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2) + """ + # Reset memory + gc.collect() + torch.cuda.empty_cache() -def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): - tensor = _get_input() + # Context and sync function for CPU offloading if cpu_offload: offload_context, sync_function = te.get_cpu_offload_context( enabled=True, - num_layers=len(models) - 1, - model_layers=len(models), + num_layers=len(modules), + model_layers=len(modules) + 1, offload_activations=True, offload_weights=False, ) else: - offload_context = nullcontext() + offload_context = contextlib.nullcontext() sync_function = lambda x: x - for model in models: + # Forward pass, with dummy step to trigger offload for last module + inp = _make_input() + tensor = inp + memory_before_forward = torch.cuda.memory_allocated() / (1024**2) + for module in modules: with te.fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe ), offload_context: - tensor = model(tensor) + tensor = module(tensor) tensor = sync_function(tensor) + with offload_context: + tensor = tensor.clone() + tensor = sync_function(tensor) + memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) - max_mem_used = torch.cuda.memory_allocated() / (1024**2) - torch.cuda.synchronize() - + # Backward pass tensor.sum().backward() + torch.cuda.synchronize() - return max_mem_used - - -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("model_key", model_types.keys()) -def test_cpu_offload(fp8_recipe, model_key) -> None: - """ - We run three configurations: - (1) No offloading: All activations remain on the GPU between forward and backward passes. - (2) No offloading (one layer): Only the first layer's activations remain on the GPU between - forward and backward passes. - (3) With offloading (all layers): Only the last layer's activations remain on the GPU - between forward and backward passes, while all other layers are offloaded to the CPU. - - We expect the memory consumption of configurations (2) and (3) to be similar, with - the difference being the size of the FP8 cache that is not offloaded to the CPU. - We also expect this memory consumption to be smaller than in scenario (1). 
- """ - import gc - - gc.collect() - - model_cls = model_types[model_key] - models_list = [model_cls() for _ in range(NUM_LAYERS)] - - if fp8_recipe and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None: - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) + # Memory usage in MiB + return memory_after_forward - memory_before_forward - without_offloading = _measure_memory_between_forward_and_backward( - models_list, fp8_recipe, False - ) - without_offloading_one_layer = _measure_memory_between_forward_and_backward( - models_list[:1], fp8_recipe, False - ) - with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True) - assert with_offloading < without_offloading +@pytest.mark.parametrize("quantization_recipe", quantization_recipes) +@pytest.mark.parametrize("model_name", model_types.keys()) +def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: + """Check that CPU offloading runs and has expected memory usage.""" - # The only difference between the memory consumption of with_offloading - # and without_offloading_one_layer should be the size of the FP8 weights cache, - # which is not offloaded to the CPU. - memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) - assert ( - memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON + # Construct model + modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] + if model_name in ["multihead_attention", "transformer_layer"]: + available_backends, *_ = get_available_attention_backends( + model_config["small"], + qkv_dtype=torch.bfloat16, + qkv_layout="sbhd_sbhd_sbhd", + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("Fused attention backend not available.") + os.environ["NVTE_FLASH_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + # Warmup + _warmup_model(modules_list, quantization_recipe) + + # Measure cached memory after forward pass + memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) + memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) + + # Check for expected memory usage + assert memory_with_offload < memory_without_offload + memory_from_cached_weights = _estimate_cached_weight_size( + model_name, + modules_list, + quantization_recipe, ) + assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 7bfe506f2..90e624c94 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -2,9 +2,7 @@ # # See LICENSE for license information. -from dataclasses import dataclass -import itertools -from typing import Iterable, List, Tuple, Union +from typing import Iterable, List, Union import pytest import torch @@ -23,43 +21,28 @@ from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe - +from utils import ModelConfig, reset_rng_states # Check if FP8 is supported. 
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() -) -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() - - -# Record initial RNG state. -seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - - -@dataclass -class ModelConfig: - """Data tensor dimensions within Transformer model""" - - sequence_length: int - batch_size: int - hidden_size: int - num_heads: int - kv_channels: int - - -model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} - -fp8_recipes = [ - recipe.DelayedScaling(), - recipe.MXFP8BlockScaling(), - recipe.Float8CurrentScaling(), - recipe.Float8BlockScaling(), -] +fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() + +# Reset RNG states. +reset_rng_states() + +model_configs = { + "small": ModelConfig(32, 2, 2, 32), +} + +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_block_scaling_available: + fp8_recipes.append(recipe.Float8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) # Supported data types dtypes: List[torch.dtype] = [torch.float32, torch.float16] @@ -67,12 +50,6 @@ class ModelConfig: dtypes.append(torch.bfloat16) -def reset_rng_states() -> None: - """Revert to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield @@ -107,7 +84,7 @@ def generate_data( """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn return gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.hidden_size, device="cuda", @@ -145,10 +122,12 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: # Supported modules _test_cuda_graphs_modules: List[str] = [ + # Put linear first to test the case where the cuda context might not be set in + # creating TMA descriptor for MXFP8 quantization. + "linear", "transformer", "layernorm_mlp", "layernorm_linear", - "linear", "mha", "linear_op", ] @@ -298,35 +277,27 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("fp8", (False, True)) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None]) def test_make_graphed_callables( *, module: str, model_config: str = "small", num_layers: int = 3, dtype: torch.dtype, - fp8: bool, fp8_params: bool, fp8_recipe: recipe.Recipe, fp8_weight_caching: bool = False, ) -> None: - # Skip invalid configurations. 
- if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) + fp8 = fp8_recipe is not None if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - - if fp8_recipe.float8_block_scaling() and module == "linear_op": + if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op": pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") + # Run model with different CUDA graph settings. model_config = model_configs[model_config] kwargs = dict( @@ -339,9 +310,11 @@ def test_make_graphed_callables( fp8_weight_caching=fp8_weight_caching, fp8_recipe=fp8_recipe, ) - outputs = _test_cuda_graphs(graph_mode="none", **kwargs) + # Put graphed callables first to test the case where the cuda context might not be set in + # creating TMA descriptor for MXFP8 quantization. graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs) + outputs = _test_cuda_graphs(graph_mode="none", **kwargs) # Check that results match. assert_all_equal(outputs, graph_outputs_mode1) @@ -357,7 +330,6 @@ def test_make_graphed_callables( ] -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize( "module", _test_make_graphed_callables_with_fp8_weight_caching_modules, @@ -373,7 +345,6 @@ def test_make_graphed_callables_with_fp8_weight_caching( test_make_graphed_callables( module=module, dtype=torch.float32, - fp8=True, fp8_params=fp8_params, fp8_recipe=fp8_recipe, fp8_weight_caching=True, @@ -389,7 +360,7 @@ def generate_data_for_dot_product_attention( gen_func = torch.ones if warmup else torch.randn return [ gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.num_heads, model_config.kv_channels, @@ -483,8 +454,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", @@ -510,8 +481,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 63833b564..c5bc7180a 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -222,7 +222,7 @@ def test_quantize_dequantize_compact_format( rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, - all_gather_usage=True, + all_gather_usage=(block_scaling_dim == 1), ) self._test_quantize_dequantize( quantizer=quantizer, diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 3527bb9a6..0bd11d941 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -4,7 +4,6 @@ # # See LICENSE for license information. 
-from itertools import product import copy from contextlib import nullcontext @@ -116,13 +115,6 @@ def test_half(self): def test_bfloat16(self): self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True) - @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - def test_multi_params(self): sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] @@ -537,13 +529,6 @@ def test_float(self): def test_half(self): self.gen_single_type_test(param_type=torch.float16) - @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - class Model(torch.nn.Module): def __init__(self): diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index ae25af949..62d80b552 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -1,25 +1,32 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -from typing import Callable, Tuple, Union +from typing import Callable, Tuple, Union, List import math import torch import pytest from transformer_engine.pytorch.attention.rope import ( RotaryPositionEmbedding, apply_rotary_pos_emb, + apply_fused_qkv_rotary_pos_emb, ) # Gradient is a broadcasted scalar -def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: - return output.sum() * 2 +def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(t.sum() * 2 for t in output) + else: + return output.sum() * 2 # Gradient is a full tensor -def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: - t = torch.ones_like(output) - return torch.sum(output * t) +def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(torch.sum(t * torch.ones_like(t)) for t in output) + else: + t = torch.ones_like(output) + return torch.sum(output * t) @pytest.mark.parametrize("start_positions", [True, False]) @@ -238,3 +245,131 @@ def test_fused_rope_thd( torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() + + +@pytest.mark.parametrize("start_positions", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096]) +@pytest.mark.parametrize("hidden_size", [64, 128, 256]) +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) +@pytest.mark.parametrize("margin", [0, 10]) +@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) +def test_fused_qkv_rope( + dtype: torch.dtype, + seq_length: int, + hidden_size: int, + rotary_percent: float, + margin: int, + tensor_format: str, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + if margin == 0 and start_positions == True: 
+ # This makes sure that the `start_positions` offsets being applied + # are with the maximum length of the rope embeddings. + pytest.skip("Skipping test with margin=0 and start_positions=True") + + if start_positions == True and cp_size > 1: + # `start_positions` is only supported for `cp_size=1` and inference. + pytest.skip("Skipping test with cp_size>1 and start_positions=True") + + if seq_length - margin < 0: + pytest.skip("Skipping test with seq_length - margin < 0") + + device = torch.device("cuda:0") + batch_size, head_num = 2, 64 + + t = torch.rand( + (seq_length - margin, batch_size, head_num, hidden_size * 6), + dtype=dtype, + device=device, + ) + + # Get arbitrary offsets to be used with RoPE for all the sequences + start_positions = ( + torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device) + if start_positions + else None + ) + + if tensor_format == "bshd": + t = t.transpose(0, 1).contiguous() + t.requires_grad = True + + rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_q = rotary_pos_emb_q(seq_length * cp_size) + rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_k = rotary_pos_emb_k(seq_length * cp_size) + + for cp_rank in range(cp_size): + # unfused + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + + t_clone = t.clone() + (query, key, value) = torch.split( + t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3 + ) + query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size) + + query_unfused = apply_rotary_pos_emb( + query, + emb_q, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + key_unfused = apply_rotary_pos_emb( + key, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + value_unfused = value + loss_unfused = loss_func([query_unfused, key_unfused, value_unfused]) + + if not isinstance(start_positions, torch.Tensor): + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + + t.grad = None + + # fused + query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb( + t, + emb_q, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + cp_size=cp_size, + cp_rank=cp_rank, + qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size], + ) + loss_fused = loss_func([query_fused, key_fused, value_fused]) + + if not isinstance(start_positions, torch.Tensor): + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(query_fused, query_unfused) + torch.testing.assert_close(key_fused, key_unfused) + torch.testing.assert_close(value_fused, value_unfused) + + if not isinstance(start_positions, torch.Tensor): + torch.testing.assert_close(grad_fused, grad_unfused) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index d2cb85dd3..fa134ba4b 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -2,8 +2,7 @@ # # See LICENSE for license information. 
import torch -import math -from typing import Optional, Dict +from typing import Optional from transformer_engine.pytorch.router import ( fused_topk_with_score_function, fused_compute_score_for_moe_aux_loss, @@ -149,11 +148,21 @@ def run_comparison( # Set some parameters if score_function == "sigmoid": # Construct the special logits to avoid inf in the sigmoid function - offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 - logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + logits = ( + torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + ) logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) else: - logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 + logits = ( + torch.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + device="cuda", + dtype=dtype, + ) + * 1e-4 + ) logits = logits.view(num_tokens, num_experts) logits.requires_grad = True if enable_bias and score_function == "sigmoid": @@ -282,11 +291,21 @@ def test_topk_softmax( def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): if score_function == "sigmoid": # Construct the special logits to avoid inf in the sigmoid function - offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 - logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + logits = ( + torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + ) logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) else: - logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 + logits = ( + torch.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + device="cuda", + dtype=dtype, + ) + * 1e-4 + ) logits = logits.view(num_tokens, num_experts) logits.requires_grad = True @@ -322,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f @pytest.mark.parametrize("topk", [4]) def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): # Construct the special probs to avoid inf in the sigmoid function - offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 - probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) probs = probs.view(num_tokens, num_experts) probs.requires_grad = True @@ -380,15 +399,12 @@ def profile_topk_softmax( if __name__ == "__main__": - test_fused_scores_for_aux_loss( - dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax" + test_topk_softmax( + dtype=torch.float32, + num_tokens=1024, + num_experts=128, + topk=4, + use_pre_softmax=False, + group_topk=None, + scaling_factor=None, ) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, 
num_experts=32, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1db81ec23..500b25f58 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -22,10 +22,13 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( - BackwardBiasActivation, + BackwardActivationBias, + BackwardAddRMSNorm, BackwardLinearAdd, + BackwardLinearScale, ForwardLinearBiasActivation, ForwardLinearBiasAdd, + ForwardLinearScaleAdd, ) from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( @@ -40,9 +43,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION # Import utility functions -_current_file = pathlib.Path(__file__).resolve() -sys.path.append(str(_current_file.parent)) -from utils import dtype_tols, make_recipe +from utils import dtype_tols, make_recipe, reset_rng_states # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -266,16 +267,72 @@ def test_module_groups(self) -> None: model(torch.zeros(1)) assert len(model._module_groups) == 6 + def test_extra_tensors(self, size: int = 16) -> None: + """Check that extra inputs are distributed properly between module groups + and that extra outputs are properly collected""" + + # Construct sequential container + bias = te_ops.Bias(size=size, device="cpu") + with torch.no_grad(): + bias.bias.copy_(torch.rand((size,))) + model = te_ops.Sequential( # | Inputs | Outputs + torch.nn.Identity(), # | x1 | x1 + te_ops.MakeExtraOutput(in_place=True), # | x1 | x1 [x1] + bias, # | x1 | h1 (= x1 + b) + te_ops.MakeExtraOutput(in_place=True), # | h1 | h1 [h1] + te_ops.AddExtraInput(in_place=True), # | h1 [x2] | x2 (= x2 + h1) + te_ops.MakeExtraOutput(in_place=True), # | x2 | x2 [x2] + torch.nn.Identity(), # | x2 | x2 + bias, # | x2 | h2 (= x2 + b) + te_ops.AddExtraInput(in_place=True), # | h2 [x3] | x3 (= x3 + h2) + te_ops.MakeExtraOutput(in_place=True), # | x3 | x3 [x3] + te_ops.AddExtraInput(in_place=True), # | x3 [x4] | x4 (= x4 + x3) + torch.nn.Identity(), # | x4 | x4 + te_ops.Identity(), # | x4 | x4 + te_ops.MakeExtraOutput(in_place=True), # | x4 | x4 [x4] + te_ops.Identity(), # | x4 | x4 + ) + + # Create input tensors + x1 = torch.rand((size,)) + x2 = torch.rand((size,)) + x3 = torch.rand((size,)) + x4 = torch.rand((size,)) + + # Save original input tensor values + x1_orig = x1.clone() + x2_orig = x2.clone() + x3_orig = x3.clone() + x4_orig = x4.clone() + + # Run forward + ys = model(x1, x2, x3, x4) + + # Check whether outputs match (x4, x1, h1, x2, x3, x4) + assert len(ys) == 6 + assert ys[0].data_ptr() == x4.data_ptr() + assert ys[1].data_ptr() == x1.data_ptr() + assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)] + assert ys[3].data_ptr() == x2.data_ptr() + assert ys[4].data_ptr() == x3.data_ptr() + assert ys[5].data_ptr() == x4.data_ptr() + + # Check whether tensors have correct values + b = bias.bias + h1 
= ys[2] + torch.testing.assert_close(x1, x1_orig) + torch.testing.assert_close(h1, x1_orig + b) + torch.testing.assert_close(x2, x2_orig + h1) + torch.testing.assert_close(x3, x3_orig + x2 + b) + torch.testing.assert_close(x4, x4_orig + x3) + class TestFuser: """Tests for operation fusion infrastructure""" @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_scale_update( @@ -489,10 +546,7 @@ class TestBasicOps: @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @@ -790,10 +844,9 @@ def _test_basic_linear( pytest.skip("FP8 output is only supported with FP8 GEMMs") if quantized_grad_input and not quantized_compute: pytest.skip("FP8 grad input is only supported with FP8 GEMMs") - if quantization == "mxfp8" and quantized_output: - pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") - if quantization == "mxfp8" and quantized_grad_input: - pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") + if quantization not in (None, "fp8"): + if quantized_output or quantized_grad_input: + pytest.skip("Recipe does not support quantized GEMM output") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1365,18 +1418,17 @@ def test_l2normalization( dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(y_test, y_ref, **tols) - # L2Norm backward pass requires slightly looser atol for bfloat16 - if dtype == torch.bfloat16: - tols["atol"] = 2e-3 torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("in_place", (True, False)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("quantization", _quantization_list) - def test_add_in_place( + def test_add_extra_input( self, *, in_shape: Iterable[int] = (32, 32), + in_place: bool, dtype: torch.dtype, device: torch.device, quantization: Optional[str], @@ -1422,7 +1474,7 @@ def test_add_in_place( dx2_ref = dy_ref # Implementation with fusible operation - op = te_ops.AddInPlace() + op = te_ops.AddExtraInput(in_place=in_place) y_test = op(x1_test, x2_test) y_test.backward(dy_test) @@ -1437,6 +1489,7 @@ def test_add_in_place( torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0) torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0) + @pytest.mark.parametrize("in_place", (True, False)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("quantization", _quantization_list) @@ -1444,6 +1497,7 @@ def test_make_extra_output( self, *, in_shape: Iterable[int] = (32, 32), + in_place: bool, dtype: torch.dtype, device: torch.device, quantization: Optional[str], @@ -1489,7 +1543,7 @@ def test_make_extra_output( (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() # Implementation with fusible operation - op = te_ops.MakeExtraOutput() + op = te_ops.MakeExtraOutput(in_place=in_place) y1_test, y2_test = op(x_test) (y1_test * dy1_test + y2_test * dy2_test).sum().backward() @@ -1502,7 +1556,10 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, x_ref.grad, **tols) - 
@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize( + "activation", + ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"), + ) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) @@ -1521,7 +1578,7 @@ def test_activation( # Tensor dimensions in_shape = list(out_shape) - if activation in ("geglu", "reglu", "swiglu"): + if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"): in_shape[-1] *= 2 # Skip invalid configurations @@ -1548,14 +1605,26 @@ def test_activation( y_ref: torch.Tensor if activation == "gelu": y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") - elif activation == "relu": - y_ref = torch.nn.functional.relu(x_ref) elif activation == "geglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "qgelu": + y_ref = x_ref * torch.sigmoid(1.702 * x_ref) + elif activation == "qgeglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = x1 * torch.sigmoid(1.702 * x1) * x2 + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) elif activation == "reglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "srelu": + y_ref = torch.nn.functional.relu(x_ref) ** 2 + elif activation == "sreglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) ** 2 * x2 + elif activation == "silu": + y_ref = torch.nn.functional.silu(x_ref) elif activation == "swiglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.silu(x1) * x2 @@ -1567,9 +1636,14 @@ def test_activation( recipe = make_recipe(quantization) make_op = dict( gelu=te_ops.GELU, - relu=te_ops.ReLU, geglu=te_ops.GEGLU, + qgelu=te_ops.QGELU, + qgeglu=te_ops.QGEGLU, + relu=te_ops.ReLU, reglu=te_ops.ReGLU, + srelu=te_ops.SReLU, + sreglu=te_ops.SReGLU, + silu=te_ops.SiLU, swiglu=te_ops.SwiGLU, )[activation] forward = te_ops.Sequential( @@ -1657,16 +1731,131 @@ def test_swiglu( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) + @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", _devices) + def test_constant_scale( + self, + *, + scale: float, + shape: Iterable[int], + dtype: torch.dtype, + device: torch.device, + ): + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = scale * x_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ConstantScale(scale) + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75)) + @pytest.mark.parametrize("is_training", (True, False)) + @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling")) + @pytest.mark.parametrize("shape", 
((101,), (2, 4, 16), (128, 128))) + @pytest.mark.parametrize("dtype", _dtypes) + def test_dropout( + self, + *, + prob: float, + is_training: bool, + quantization: Optional[str], + shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + ): + + # Skip invalid configurations + quantized_input = quantization is not None + maybe_skip_quantization(quantization, dims=shape, device=device) + + # Random data + # Note: Shift values to make sure inputs are non-zero + x_ref, x_test = make_reference_and_test_tensors( + shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + test_is_quantized=quantized_input, + ) + with torch.no_grad(): + x_test += 1 + x_ref.copy_(x_test) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Apply dropout + op = te_ops.Dropout(prob) + if is_training: + op.train() + else: + op.eval() + y_test = op(x_test) + y_test.backward(dy_test) + + # Check values + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + if is_training: + tols = dtype_tols(dtype) + mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype) + torch.testing.assert_close(y_test, x_ref * mask, **tols) + torch.testing.assert_close(dx_test, dy_ref * mask, **tols) + else: + torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0) + + # Hypothesis testing for number of zeros + # Note: A Bernoulli random variable with probability p has + # mean p and standard deviation sqrt(p*(1-p)). By the central + # limit theorem, the mean of n iid Bernoulli variables + # converges to a normal random variable with mean p and + # standard deviation sqrt(p*(1-p)/n). If the observed mean is + # below the 0.5th or above the 99.5th percentiles, then the + # p-value is less than 1% and we assume that the dropout + # distribution is incorrect. 
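+        # For example, with prob=0.5 and a (128, 128) input (n = 16384), the
+        # standard error is sqrt(0.25 / 16384) ~= 0.0039, so the observed zero
+        # fraction must land within roughly 0.5 +/- 0.010 (i.e. |z| < 2.5758,
+        # the two-sided 99% critical value) for the check below to pass.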
+ if is_training: + prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel() + z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel()) + assert ( + abs(z_score) < 2.5758 + ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" + class TestFusedOps: """Tests for fused operations""" @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @@ -1853,7 +2042,7 @@ def test_forward_linear_bias_add( device=device, dtype=dtype, ), - te_ops.AddInPlace(), + te_ops.AddExtraInput(in_place=True), ) with torch.no_grad(): model[0].weight.copy_(w_test) @@ -1890,11 +2079,114 @@ def test_forward_linear_bias_add( db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_forward_linear_scale_add( + self, + *, + scale: float, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool = False, + ) -> None: + """Forward GEMM + scale + add""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operations + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=False, + device=device, + dtype=dtype, + ), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + te_ops.Quantize(), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + del w_test + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x1_test, x2_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 2 + assert isinstance(forward_ops[0][0], ForwardLinearScaleAdd) + assert 
isinstance(forward_ops[1][0], te_ops.Quantize) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx1_test, x1_ref.grad, **tols) + torch.testing.assert_close(dx2_test, x2_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu")) @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) - def test_backward_bias_activation( + def test_backward_activation_bias( self, *, activation: str, @@ -1903,7 +2195,7 @@ def test_backward_bias_activation( device: torch.device = "cuda", quantization: Optional[str], ) -> None: - """Backward dbias + dact + quantize""" + """Backward dact + dbias + quantize""" # Tensor dimensions in_shape = list(out_shape) @@ -1960,9 +2252,9 @@ def test_backward_bias_activation( # Check that backward operations have been fused backward_ops = model._module_groups[0]._backward_ops - if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]: + if with_quantization: assert len(backward_ops) == 2 - assert isinstance(backward_ops[0][0], BackwardBiasActivation) + assert isinstance(backward_ops[0][0], BackwardActivationBias) assert isinstance(backward_ops[1][0], te_ops.Quantize) else: assert len(backward_ops) == 3 @@ -1975,6 +2267,7 @@ def test_backward_bias_activation( if with_quantization: tols = dtype_tols(tex.DType.kFloat8E4M3) + # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") @@ -1982,6 +2275,94 @@ def test_backward_bias_activation( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("weight_shape", ((19,), (64,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("zero_centered_gamma", (False, True)) + def test_backward_add_rmsnorm( + self, + *, + weight_shape: Iterable[int], + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + eps: float = 0.3, + zero_centered_gamma: bool, + ) -> None: + """Fused backward RMNorm + add""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy1_ref, dy1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + dy2_ref, dy2_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + inner_dims = tuple(range(len(in_shape) - len(weight_shape), 
len(in_shape))) + var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape) + if zero_centered_gamma: + y1_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref) + else: + y1_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref + y2_ref = x_ref + (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() + + # Implementation with fusible operations + model = te_ops.Sequential( + te_ops.MakeExtraOutput(), + te_ops.RMSNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + ), + ) + with torch.no_grad(): + model[1].weight.copy_(w_test) + del w_test + y1_test, y2_test = model(x_test) + (y1_test * dy1_test + y2_test * dy2_test).sum().backward() + + # Check that backward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], BackwardAddRMSNorm) + + # Expected numerical error + tols = dtype_tols(dtype) + + # Check results + y1_test = y1_test.to(dtype=torch.float64, device="cpu") + y2_test = y2_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y1_test, y1_ref, **tols) + torch.testing.assert_close(y2_test, y2_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) def test_backward_linear_add( @@ -2045,7 +2426,7 @@ def test_backward_linear_add( recipe = make_recipe(quantization) with te.fp8_model_init(enabled=quantized_weight): model = te_ops.Sequential( - te_ops.MakeExtraOutput(), + te_ops.MakeExtraOutput(in_place=True), te_ops.Linear( in_features, out_features, @@ -2083,16 +2464,106 @@ def test_backward_linear_add( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_backward_linear_scale( + self, + *, + scale: float, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool = False, + ) -> None: + """Backward dgrad GEMM + scale""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain 
PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale + y_ref.backward(dy_ref) + + # Implementation with fusible operations + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=False, + device=device, + dtype=dtype, + ), + te_ops.ConstantScale(scale), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + del w_test + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + (y_test * dy_test).sum().backward() + + # Check that backward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], BackwardLinearScale) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + class TestCheckpointing: """Tests for checkpointing""" @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) @@ -2204,11 +2675,9 @@ class TestSequentialModules: @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() + @pytest.mark.parametrize("requires_grad", (False, True)) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm")) @pytest.mark.parametrize("quantized_compute", (False, True)) @@ -2218,6 +2687,7 @@ def setup_class(cls) -> None: def test_layernorm_mlp( self, *, + requires_grad: bool, bias: bool, normalization: str, quantized_compute: bool, @@ -2258,6 +2728,7 @@ def test_layernorm_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + requires_grad=requires_grad, ) _, dy_test = make_reference_and_test_tensors( in_shape, diff --git a/tests/pytorch/test_hf_integration.py b/tests/pytorch/test_hf_integration.py index 0b2468510..e74b16022 100644 --- a/tests/pytorch/test_hf_integration.py +++ b/tests/pytorch/test_hf_integration.py @@ -7,7 +7,6 @@ from transformers.modeling_utils import PreTrainedModel from transformer_engine.pytorch.transformer import TransformerLayer -from transformer_engine.pytorch.utils import is_bf16_compatible class SimpleTEModel(PreTrainedModel): diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b2885d677..a4dfd64ba 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1,10 +1,9 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
-from collections import OrderedDict import math import os from typing import Dict, List, Tuple, Optional @@ -43,31 +42,34 @@ Fp8Padding, Fp8Unpadding, ) -from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace -from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.common import recipe import transformer_engine_torch as tex +from utils import ModelConfig, reset_rng_states, get_available_attention_backends +if IS_HIP_EXTENSION: + from utils import EnvVarCleaner + # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() -) +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() sm_80plus = get_device_compute_capability() >= (8, 0) seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -# Record initial RNG state from script run. -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +# Reset RNG states. 
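+# (reset_rng_states comes from tests/pytorch/utils and restores the CPU and
+# CUDA generators to a fixed initial state so tests stay reproducible.)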
+reset_rng_states() if torch.__version__ >= '2.7.0': torch._dynamo.config.recompile_limit = 16 @@ -83,24 +85,12 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) != 0) -class ModelConfig: - def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len): - self.hidden_size = hidden_size - self.eps = eps - self.num_attention_heads = num_attention_heads - self.embed = embed - self.num_layers = num_layers - self.seq_len = seq_len - - model_configs = { - "small": ModelConfig(128, 1e-5, 8, 36, 4, 128), - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), + "small": ModelConfig(1, 128, 8, 16, num_layers=4), + "126m": ModelConfig(1, 2048, 12, 64, num_layers=12), } - model_configs_inference = { - # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256), + "126m": ModelConfig(1, 256, 12, 64, num_layers=12), } backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"] module_inference = ["TransformerLayer", "MultiheadAttention"] @@ -114,7 +104,18 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq all_boolean = [True, False] -all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] +all_activations = [ + "gelu", + "geglu", + "qgelu", + "qgeglu", + "relu", + "reglu", + "srelu", + "sreglu", + "silu", + "swiglu", +] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -134,12 +135,39 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], ) -fp8_recipes = [ - recipe.MXFP8BlockScaling(), - recipe.DelayedScaling(), - recipe.Float8CurrentScaling(), - recipe.Float8BlockScaling(), -] + +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_block_scaling_available: + fp8_recipes.append(recipe.Float8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) + +use_cutlass_grouped_gemm = [False] +# Only enable cutlass grouped gemm on Hopper +if torch.cuda.get_device_capability() == (9, 0): + use_cutlass_grouped_gemm.append(True) + + +def is_fused_attn_available( + config: ModelConfig, + dtype: torch.dtype, + qkv_layout="bshd_bshd_bshd", + is_training=True, + deterministic=False, +): + _, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + deterministic=deterministic, + ) + if IS_HIP_EXTENSION: + return fused_attn_backends != [] + return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends def get_causal_attn_mask(sq: int) -> torch.Tensor: @@ -204,40 +232,12 @@ def assert_allclose( raise AssertionError(msg) -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield FP8GlobalStateManager.reset() -class EnvVarCleaner: - def __init__(self, envs_): - self.envs = envs_ - self.flags = {} - for env in self.envs: - if env in os.environ: - self.flags[env] = os.environ[env] - def __del__(self): - for env in self.envs: - if env in self.flags: - os.environ[env] = self.flags[env] - else: - os.environ.pop(env, None) - - -@pytest.fixture -def reset_test_envs(): - env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", - 
"NVTE_BIAS_GELU_NVFUSION"]) - yield - - class TorchScaledMaskedSoftmax(nn.Module): def __init__(self) -> None: super().__init__() @@ -488,13 +488,16 @@ def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: _supported_act = { - "geglu": nn.GELU(approximate="tanh"), "gelu": nn.GELU(approximate="tanh"), - "reglu": nn.ReLU(), - "relu": nn.ReLU(), - "swiglu": nn.SiLU(), + "geglu": nn.GELU(approximate="tanh"), "qgelu": TorchQuickGELU(), + "qgeglu": TorchQuickGELU(), + "relu": nn.ReLU(), + "reglu": nn.ReLU(), "srelu": TorchSquaredRELU(), + "sreglu": TorchSquaredRELU(), + "silu": nn.SiLU(), + "swiglu": nn.SiLU(), } @@ -584,13 +587,13 @@ def _test_e2e_selective_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -599,13 +602,13 @@ def _test_e2e_selective_recompute( ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): te_out = block( @@ -631,14 +634,8 @@ def _test_e2e_selective_recompute( @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -680,13 +677,13 @@ def _test_e2e_full_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -695,14 +692,14 @@ def _test_e2e_full_recompute( ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=use_reentrant, ) if use_reentrant: te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if recompute: @@ -745,30 +742,25 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) -@pytest.mark.usefixtures("reset_test_envs") def test_gpt_full_activation_recompute( dtype, bs, model, fp8, recipe, fp8_model_params, 
use_reentrant ): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): - if (dtype == torch.bfloat16 + if (dtype == torch.bfloat16 and not fp8 - and not use_reentrant - and recipe.float8_per_tensor_scaling() + and not use_reentrant + and recipe.float8_per_tensor_scaling() ): - pytest.skip("hipBLASLt does not provide suitable algorithms on MI350 for this config.") + pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow if not use_reentrant: + if IS_HIP_EXTENSION: + env = EnvVarCleaner(["NVTE_BIAS_GELU_NVFUSION"]) # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0" @@ -825,13 +817,13 @@ def _test_e2e_checkpointing_get_model(config, dtype): return TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -843,7 +835,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= reset_rng_states() te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -873,14 +865,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= if p.requires_grad: param_grads.append(p.grad.clone()) - global _cpu_rng_state, _cuda_rng_state _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() del block block = _test_e2e_checkpointing_get_model(config, dtype) block.load_state_dict(torch.load(path, weights_only=False)) - reset_rng_states() + torch.set_rng_state(_cpu_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) for p in block.parameters(): if p.requires_grad: @@ -913,6 +905,9 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] + if not is_fused_attn_available(config, dtype, deterministic=True): + pytest.skip("No attention backend available.") + outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -939,13 +934,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) out = block(inp_hidden_states, attention_mask=inp_attn_mask) loss = out.sum() @@ -965,11 +960,15 @@ def _test_e2e_gpt_accuracy(block, 
bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): + pytest.skip("No attention backend available.") te_gpt = TransformerLayer( hidden_size=config.hidden_size, ffn_hidden_size=4 * config.hidden_size, - num_attention_heads=config.num_attention_heads, + num_attention_heads=config.num_heads, layernorm_epsilon=config.eps, attention_dropout=0.1, hidden_dropout=0.1, @@ -984,7 +983,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): TorchGPT( config.hidden_size, config.eps, - config.num_attention_heads, + config.num_heads, parallel_attention_mlp=parallel_attention_mlp, ) .to(dtype=dtype) @@ -1045,13 +1044,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None forward_kwargs = {} if te: @@ -1076,10 +1075,14 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): + pytest.skip("No attention backend available.") te_mha = MultiheadAttention( config.hidden_size, - config.num_attention_heads, + config.num_heads, fuse_qkv_params=True, params_dtype=dtype, qkv_weight_interleaved=False, @@ -1090,7 +1093,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): torch_mha = ( TorchMHA( config.hidden_size, - config.num_attention_heads, + config.num_heads, ) .to(dtype=dtype) .cuda() @@ -1136,7 +1139,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1168,7 +1171,7 @@ def _test_granular_accuracy_with_fp8(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1192,11 +1195,12 @@ def _test_dpa_accuracy(block, bs, dtype, config): reset_rng_states() mask = torch.triu( - torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1 + torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"), + diagonal=1, ) query, key, value = [ torch.randn( - (config.seq_len, bs, config.num_attention_heads, config.embed), + (config.max_seqlen_q, bs, config.num_heads, config.kv_channels), dtype=dtype, device="cuda", requires_grad=True, @@ -1225,8 +1229,8 @@ def test_dpa_accuracy(dtype, bs, model): te_dpa = ( DotProductAttention( - config.num_attention_heads, - config.embed, + config.num_heads, + config.kv_channels, attention_dropout=0.0, # disable dropout, FU uses rng differently ) .to(dtype=dtype) @@ -1235,7 +1239,7 @@ def 
test_dpa_accuracy(dtype, bs, model): torch_dpa = ( TorchDotProductAttention( - config.embed, + config.kv_channels, 0.0, # dropout ) .to(dtype=dtype) @@ -1418,8 +1422,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ te_linear_ref, bs, dtype, config, delay_wgrad_compute=False ) - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + # Should be bit-wise match + for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) @@ -1431,17 +1435,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): fuse_wgrad_accumulation = True fp8_model_params = False fp8 = recipe is not None - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) + if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -1811,7 +1810,7 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("bs", [2]) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("normalization", all_normalizations) @@ -1881,14 +1880,12 @@ def test_fp8_layernorm_mlp_without_transpose_cache_accuracy(dtype, bs, model, ac @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("bs", [2]) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("activation", all_activations) -@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) def test_layernorm_mlp_accuracy_delay_wgrad_compute( - dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation + dtype, bs, model, bias, fuse_wgrad_accumulation ): config = model_configs[model] @@ -1901,7 +1898,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ffn_hidden_size=4 * config.hidden_size, eps=config.eps, bias=bias, - normalization=normalization, params_dtype=dtype, device="cuda", delay_wgrad_compute=True, @@ -1913,7 +1909,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ffn_hidden_size=4 * config.hidden_size, eps=config.eps, bias=bias, - normalization=normalization, params_dtype=dtype, device="cuda", delay_wgrad_compute=False, @@ -1923,8 +1918,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( # Share params with torch.no_grad(): ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - if normalization != "RMSNorm": - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) if bias: @@ -1962,7 +1956,7 @@ def 
_test_grouped_linear_accuracy( FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1975,14 +1969,14 @@ def _test_grouped_linear_accuracy( split_size = 16 if recipe.mxfp8(): split_size = 128 - m = config.seq_len // split_size + m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) m_splits = m_splits * split_size - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms else: - m_splits = torch.tensor([config.seq_len]) + m_splits = torch.tensor([config.max_seqlen_q]) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, GroupedLinear): @@ -2036,23 +2030,18 @@ def test_grouped_linear_accuracy( bias, delay_wgrad_compute, parallel_mode=None, + use_cutlass=False, ): fp8 = recipe is not None if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8: pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2117,9 +2106,47 @@ def test_grouped_linear_accuracy( delay_wgrad_compute, ) - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + for o, o_ref in zip(outputs, outputs_ref): + if use_cutlass: + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + else: + # cuBLAS implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_accuracy_cutlass( + dtype, + num_gemms, + bs, + model, + fuse_wgrad_accumulation, + delay_wgrad_compute, +): + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + None, + False, + fuse_wgrad_accumulation, + False, + delay_wgrad_compute, + None, + use_cutlass=True, + ) + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) @pytest.mark.parametrize("dtype", param_types, ids=str) @@ -2144,19 +2171,13 @@ def test_grouped_linear_accuracy_save_original_input( parallel_mode=None, ): fp8 = recipe is not None - if fp8 and not 
fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2304,14 +2325,14 @@ def _generate_random_numbers(n, total_sum): FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len * bs, config.hidden_size), + (config.max_seqlen_q * bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) + m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): @@ -2354,17 +2375,11 @@ def test_padding_grouped_linear_accuracy( fp8_model_params, parallel_mode=None, ): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2429,19 +2444,13 @@ def test_padding_grouped_linear_accuracy_save_original_input( fp8_model_params, parallel_mode=None, ): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2498,9 +2507,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): # Placeholders used for graph capture. 
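+    # CUDA graph replay reuses fixed device addresses, so inputs and targets
+    # are allocated once as static tensors and fresh data is copied into them
+    # before each replay.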
static_input = torch.randn( - config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + ) + static_target = torch.randn( + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype ) - static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype) real_input = torch.rand_like(static_input) real_target = torch.rand_like(static_target) @@ -2570,7 +2581,7 @@ def test_gpt_cuda_graph(dtype, bs, model): block_args = ( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, ) block_kwargs = dict( layernorm_epsilon=config.eps, @@ -2578,7 +2589,7 @@ def test_gpt_cuda_graph(dtype, bs, model): output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2613,13 +2624,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -2628,13 +2639,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=True, fp8_recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) @@ -2654,14 +2665,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("recipe", fp8_recipes) def test_gpt_fp8_parameters(dtype, bs, model, recipe): - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -2696,13 +2701,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_sbhd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2717,13 +2722,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_bshd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, 
output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2735,13 +2740,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_thd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2756,15 +2761,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical" x_sbhd = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) x_bshd = x_sbhd.transpose(0, 1).contiguous() - x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous() - x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len + x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous() + x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -2809,192 +2814,27 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): x_thd, cu_seqlens_q=x_thd_cumsum, cu_seqlens_kv=x_thd_cumsum, - max_seqlen_q=config.seq_len, - max_seqlen_kv=config.seq_len, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, ) if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): tols_thd = dtype_tols(dtype) - # On gfx950 the results for THD are different + # On gfx950 the results for THD are different # that results in lower final result precision tols_thd["atol"] = 2e-3 torch.testing.assert_close( y_bshd, - y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(), **tols_thd, ) else: torch.testing.assert_close( y_bshd, - y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(), ) -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model_key", model_configs_inference.keys()) -@pytest.mark.parametrize("use_RoPE", all_boolean) -@pytest.mark.parametrize("input_format", input_formats_inference) -@pytest.mark.parametrize("module", module_inference) -@pytest.mark.parametrize("backend", backends_inference) -@pytest.mark.parametrize("is_paged", [False, True]) -@pytest.mark.usefixtures("reset_test_envs") -def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged): - if ((backend == "FlashAttention" and os.getenv("NVTE_FLASH_ATTN", "1") == "0") or - (backend == "FusedAttention" and os.getenv("NVTE_FUSED_ATTN", "1") == "0")): - pytest.skip(f"{backend} is disabled") - - if backend == "FlashAttention" and not fa_utils.is_installed: - pytest.skip("FlashAttention is not installed") - - if IS_HIP_EXTENSION and backend == "FusedAttention": - if is_paged: - pytest.skip("FusedAttention does not 
support KV cache with paging on ROCm") - if os.getenv("NVTE_FUSED_ATTN_CK", "1") == "0": - pytest.skip("CK FusedAttention backend is disabled") - - reset_rng_states() - - if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: - pytest.skip("FusedAttention and FlashAttention do not support FP32") - if use_RoPE: - pytest.skip("KV cache does not support starting positions for RoPE") - if ( - not IS_HIP_EXTENSION and - backend == "FusedAttention" - and get_device_compute_capability() == (8, 9) - and get_cudnn_version() < (9, 12, 0) - ): - pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12") - - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - - if backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - elif backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - elif backend == "UnfusedAttention": - os.environ["NVTE_UNFUSED_ATTN"] = "1" - - config = model_configs_inference[model_key] - - S = config.seq_len - B = bs - H = config.num_attention_heads - D = config.hidden_size - head_size = config.embed - layer_number = 1 - - # Limits the max size of KV-cache - B_max = B - S_max = S - - if module == "TransformerLayer": - model = TransformerLayer( - hidden_size=D, - ffn_hidden_size=4 * D, - num_attention_heads=H, - attn_input_format=input_format, - self_attn_mask_type="causal", - enc_dec_attn_mask_type="causal", - layer_number=layer_number, - attention_dropout=0.0, - params_dtype=dtype, - device="cuda", - ).eval() - else: - model = ( - MultiheadAttention( - hidden_size=D, - num_attention_heads=H, - qkv_format=input_format, - layer_number=layer_number, - attention_dropout=0.0, - attn_mask_type="causal", - params_dtype=dtype, - ) - .cuda() - .eval() - ) - - inference_params = InferenceParams( - max_batch_size=B_max, - max_sequence_length=S_max, - num_heads_kv=H, - head_dim_k=head_size, - dtype=dtype, - is_paged=is_paged, - total_num_pages=int(B_max * S_max / 256), - page_size=256, - ) - - rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") - - input = torch.randn((S, B, D), dtype=dtype, device="cuda") - if input_format == "bshd": - input = input.transpose(0, 1).contiguous() - - incremental_output = torch.zeros_like(input) - - # Generate output for the entire sequence - full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None) - - # Incrementaly generate outputs using KV-cache - step_dict = OrderedDict(zip(list(range(B)), [1] * B)) - for i in range(S): - inference_params.pre_step(step_dict) - - if input_format == "sbhd": - incremental_input = input[i].view(1, B, D) - else: - incremental_input = input[:, i, :].view(B, 1, D) - - seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") - cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - cu_seqlens_kv = cu_seqlens_q.clone() - - mask_type = "padding" - kwargs = {} - if module == "TransformerLayer": - kwargs["self_attn_mask_type"] = mask_type - else: - kwargs["attn_mask_type"] = mask_type - line_output = model( - hidden_states=incremental_input, - inference_params=inference_params, - rotary_pos_emb=rotary_freqs if use_RoPE else None, - **kwargs, - max_seqlen_q=1, - max_seqlen_kv=S, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - ) - - if input_format == "sbhd": - incremental_output[i, :, :] = line_output.view(B, D) - else: - incremental_output[:, i, :] = line_output.view(B, D) - - if 
module == "TransformerLayer": - atol = { - torch.float32: 5e-3, - torch.half: 5e-3, - torch.bfloat16: 5e-2, - } - else: - atol = { - torch.float32: 1e-3, - torch.half: 1e-3, - torch.bfloat16: 1e-2, - } - - # Check if the fully generated output matches the one generated incrementally - assert_allclose(full_output, incremental_output, atol[dtype]) - - @pytest.mark.parametrize( "shape", [ @@ -3004,10 +2844,11 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, (16, 10027, 128, 512), ], ) -@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) -def test_grouped_gemm(shape, dtype, layout, accumulate): +@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): torch.manual_seed(0) z, m, k, n = shape @@ -3042,6 +2883,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): grad = True single_output = False + if use_cutlass: + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + for i in range(z): general_gemm( A[i], @@ -3069,9 +2913,87 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): single_output=single_output, ) - # should be bit-wise match for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + if not use_cutlass: + # cublas implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + + if use_cutlass: + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + + +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "input_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + ], +) +@pytest.mark.parametrize( + "out_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + Float8Quantizer( + torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3 + ), + ], +) +def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer): + # For MXFP8 and CurrentScaling, below unfused quantization should happen + # FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output + # Skip invalid configurations + is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance( + out_quantizer, MXFP8Quantizer + ) + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if is_mxfp8_needed and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): + if isinstance(input_quantizer, MXFP8Quantizer): + N = math.ceil(N / 128) * 128 #hipBlasLt supports K which is multiple of 128 for MXFP8 + if not is_mxfp8_needed and isinstance(out_quantizer, Float8Quantizer): + pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") + inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + outp_type = torch.float32 + quantized_out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=out_quantizer, + bias=None, + 
use_split_accumulator=False, + ) + + out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=None, + bias=None, + use_split_accumulator=False, + ) + expected_quantized_out = out_quantizer(out) + + # Match results again Pytorch GEMM and allow for quantization tolerance + pytorch_out = torch.matmul( + inp_fp8.dequantize().to(torch.float64), + torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1), + ) + fp8_tols = dict(rtol=0.125, atol=0.0675) + torch.testing.assert_close( + pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols + ) + # Match results between quantization happening inside vs outside general_gemm + torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize()) @pytest.mark.parametrize( @@ -3082,9 +3004,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): (16, 4096, 128, 512), ], ) -@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("accumulate", [False, True]) -def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): +def test_fp8_grouped_gemm(shape, accumulate): if not fp8_available: pytest.skip(reason_for_no_fp8) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index ea9c85e37..e5368497d 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -27,7 +27,6 @@ import numpy as np import onnxruntime as ort import torch -import random from torch import nn as nn from typing import Optional, Union, Tuple, List from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op @@ -37,6 +36,7 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import get_default_init_method +import tensorrt as trt # Global test configuration knobs. 
@@ -59,14 +59,14 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) -fp8_recipes = [ - None, - recipe.DelayedScaling(), - recipe.MXFP8BlockScaling(), -] +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.DelayedScaling()) + fp8_recipes.append(recipe.Float8CurrentScaling()) +fp8_recipes.append(None) supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] @@ -82,11 +82,11 @@ ], outputs=[PyCustomOpDef.dt_uint8], ) -def trt_fp8_quantize(t, scale): +def trt_fp8_quantize(t, scale_inv): """FP8 quantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() q = te.tensor.float8_tensor.Float8Quantizer( - scale=1 / torch.from_numpy(scale).cuda(), + scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, ) @@ -102,11 +102,11 @@ def trt_fp8_quantize(t, scale): ], outputs=[PyCustomOpDef.dt_float], ) -def trt_fp8_dequantize(t, scale): +def trt_fp8_dequantize(t, scale_inv): """FP8 dequantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() q = te.tensor.float8_tensor.Float8Quantizer( - scale=1 / torch.from_numpy(scale).cuda(), + scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, ) @@ -115,7 +115,7 @@ def trt_fp8_dequantize(t, scale): @onnx_op( - op_type="trt::TRT_MXFP8QuantizeLinear", + op_type="trt::TRT_MXFP8DynamicQuantize", domain="trt", inputs=[ PyCustomOpDef.dt_float, @@ -369,14 +369,6 @@ def create_ort_input_dict(session, inputs): ) -def create_meta(scale_factor: float, size: int = 1): - meta = tex.FP8TensorMeta() - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor - return meta - - def dtype2str(dtype: torch.dtype, fake_bf16_io=False): if fake_bf16_io: assert dtype == torch.bfloat16 @@ -413,36 +405,12 @@ def get_attn_mask_str(use_mask, attn_mask_type): """ -@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [True, False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, False), - (torch.float16, True), - # Todo: cannot configure BF16 when bias is disabled (ORT issue?) - (torch.bfloat16, False), - # Todo: cannot configure BF16 when bias is enabled (ORT issue?) 
- (torch.bfloat16, True), - ], -) -def test_export_linear( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - use_bias: bool, - return_bias: bool, - precision: torch.dtype, +def _test_export_linear( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_bias: bool = True, + return_bias: bool = False, + precision: torch.dtype = torch.float32, ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if return_bias and not use_bias: pytest.skip("Cannot return bias when bias is disabled") @@ -498,32 +466,28 @@ def forward(self, inp): ) -@pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize( - "precision", - [ - torch.float32, - torch.float16, - torch.bfloat16, - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - precision: torch.dtype, - zero_centered_gamma: bool, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision): + _test_export_linear(fp8_recipe=fp8_recipe, precision=precision) + + +@pytest.mark.parametrize("use_bias", [True, False]) +def test_export_linear_use_bias(seed_default_rng, use_bias): + _test_export_linear(use_bias=use_bias) + + +@pytest.mark.parametrize("return_bias", [True, False]) +def test_export_linear_return_bias(seed_default_rng, return_bias): + _test_export_linear(return_bias=return_bias) + +def _test_export_layernorm( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + precision: torch.dtype = torch.float32, + zero_centered_gamma: bool = False, + normalization: str = all_normalizations[0], +): # Set dimensions (these are arbitrary). 
batch_size = 4 in_features = 64 @@ -564,39 +528,31 @@ def test_export_layernorm( ) -@pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("return_bias", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision): + _test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_layernorm_zero_centered_gamma(seed_default_rng): + _test_export_layernorm(zero_centered_gamma=True) + + @pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_linear( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - normalization: str, +def test_export_layernorm_normalization(seed_default_rng, normalization): + _test_export_layernorm(normalization=normalization) + + +def _test_export_layernorm_linear( + scale_factor: float = 112, + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_bias: bool = True, + return_bias: bool = False, + return_layernorm_output: bool = False, + precision: torch.dtype = torch.float32, + zero_centered_gamma: bool = False, + normalization: str = all_normalizations[0], ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if return_bias and not use_bias: pytest.skip("Cannot return bias when bias is disabled") @@ -638,47 +594,52 @@ def test_export_layernorm_linear( fname, inp, model, - atol=1e-3, + # For current scaling we use Float8Quantizer in tests + amax computed by hand, + # which has slightly different numerics than Float8CurrentScalingQuantizer. 
+ atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2, is_fp8=fp8_recipe is not None, te_outputs=te_outputs, ) -@pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("return_bias", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_mlp( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - activation: str, - normalization: str, +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision): + _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_layernorm_linear_return_ln_out(seed_default_rng): + _test_export_layernorm_linear(return_layernorm_output=True) + + +def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng): + _test_export_layernorm_linear(zero_centered_gamma=True) + + +@pytest.mark.parametrize("normalization", all_normalizations[1:]) +def test_export_layernorm_linear_normalization(seed_default_rng, normalization): + _test_export_layernorm_linear(normalization=normalization) + + +def test_export_layernorm_linear_no_bias(seed_default_rng): + _test_export_layernorm_linear(use_bias=False) + + +def test_export_layernorm_linear_return_bias(seed_default_rng): + _test_export_layernorm_linear(return_bias=True) + + +def _test_export_layernorm_mlp( + scale_factor: float = 112, + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_bias: bool = True, + return_bias: bool = False, + return_layernorm_output: bool = False, + precision: torch.dtype = torch.float32, + zero_centered_gamma: bool = False, + activation: str = supported_activations[0], + normalization: str = all_normalizations[0], ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if return_bias and not use_bias: pytest.skip("Cannot return bias when bias is disabled") @@ -720,6 +681,38 @@ def test_export_layernorm_mlp( ) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision): + _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng): + _test_export_layernorm_mlp(return_layernorm_output=True) + + +def test_export_layernorm_mlp_return_bias(seed_default_rng): + _test_export_layernorm_mlp(return_bias=True) + + +def test_export_layernorm_mlp_no_bias(seed_default_rng): + _test_export_layernorm_mlp(use_bias=False) + + +def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng): + _test_export_layernorm_mlp(zero_centered_gamma=True) + + 
+@pytest.mark.parametrize("normalization", all_normalizations[1:]) +def test_export_layernorm_mlp_normalization(seed_default_rng, normalization): + _test_export_layernorm_mlp(normalization=normalization) + + +@pytest.mark.parametrize("activation", supported_activations[1:]) +def test_export_layernorm_mlp_activation(seed_default_rng, activation): + _test_export_layernorm_mlp(activation=activation) + + @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ @@ -734,8 +727,6 @@ def test_export_layernorm_mlp( ], ) def test_export_core_attention( - seed_default_rng, - set_max_seq_len, precision: torch.dtype, use_mask: bool, attn_mask_type: str, @@ -777,11 +768,6 @@ def test_export_core_attention( ) -test_configs_multihead_attention = [ - # "use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledSoftmax - (True, "arbitrary"), # calls ScaledMaskedSoftmax -] test_configs_attention_type = [ # "input_layernorm, attention_type, fuse_qkv_params" (True, "self", True), @@ -795,31 +781,14 @@ def test_export_core_attention( ] -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type -) -def test_export_multihead_attention( - seed_default_rng, - set_max_seq_len, - fp8_recipe: recipe.Recipe, - use_mask: bool, - attn_mask_type: str, - precision: torch.dtype, - return_layernorm_output: bool, - input_layernorm: bool, - attention_type: str, - fuse_qkv_params: bool, +def _test_export_multihead_attention( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_mask: bool = True, + precision: torch.dtype = torch.float32, + input_layernorm: bool = True, + attention_type: str = "self", + fuse_qkv_params: bool = True, ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - hidden_size = 256 sequence_length = 128 batch_size = 4 @@ -837,6 +806,7 @@ def test_export_multihead_attention( init_method, output_layer_init_method, ) + attn_mask_type = "arbitrary" if use_mask else "no_mask" hidden_states_context = torch.randn( sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" @@ -868,7 +838,7 @@ def test_export_multihead_attention( *attention_args, attn_mask_type=attn_mask_type, params_dtype=precision, - return_layernorm_output=return_layernorm_output, + return_layernorm_output=False, input_layernorm=input_layernorm, attention_type=attention_type, fuse_qkv_params=fuse_qkv_params, @@ -960,30 +930,37 @@ def test_export_multihead_attention( @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("output_layernorm", [True, False]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("fuse_qkv_params", [False, True]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -def test_export_transformer_layer( - seed_default_rng, - set_max_seq_len, - fp8_recipe: recipe.Recipe, - use_mask: bool, - attn_mask_type: str, - output_layernorm: bool, - 
precision: torch.dtype, - fuse_qkv_params: bool, - zero_centered_gamma: bool, - activation: str, -): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) +def test_export_multihead_attention_recipe(fp8_recipe, precision): + _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_multihead_attention_no_mask(): + _test_export_multihead_attention(use_mask=False) + + +def test_export_multihead_attention_no_input_layernorm(): + _test_export_multihead_attention(input_layernorm=False) + + +def test_export_multihead_attention_cross_attn(): + _test_export_multihead_attention(attention_type="cross") + +def test_export_multihead_attention_unfused_qkv_params(): + _test_export_multihead_attention(fuse_qkv_params=False) + + +def _test_export_transformer_layer( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_mask: bool = True, + attn_mask_type: str = "arbitrary", + output_layernorm: bool = False, + precision: torch.dtype = torch.float32, + fuse_qkv_params: bool = True, + zero_centered_gamma: bool = False, + activation: str = supported_activations[0], +): # Layer configuration hidden_size = 64 sequence_length = 128 @@ -1043,28 +1020,43 @@ def test_export_transformer_layer( ) -@skip_FP8 -@skip_MXFP8 +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_transformer_layer_recipe(fp8_recipe, precision): + _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_transformer_layer_no_mask(): + _test_export_transformer_layer(use_mask=False) + + +def test_export_transformer_layer_output_layernorm(): + _test_export_transformer_layer(output_layernorm=True) + + +def test_export_transformer_layer_unfused_qkv_params(): + _test_export_transformer_layer(fuse_qkv_params=False) + + +def test_export_transformer_layer_zero_centered_gamma(): + _test_export_transformer_layer(zero_centered_gamma=True) + + +@pytest.mark.parametrize("activation", supported_activations[1:]) +def test_export_transformer_layer_activation(activation): + _test_export_transformer_layer(activation=activation) + + @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [True]) def test_export_gpt_generation( - seed_default_rng, - set_max_seq_len, fp8_recipe: recipe.Recipe, precision: torch.dtype, - zero_centered_gamma: bool, ): """Test that the ONNX model can correctly handle inputs with different shapes and that the attention mask is adjusted on-the-fly to different sequence lengths. 
""" - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - # Layer configuration hidden_size = 64 sequence_length = 128 @@ -1091,7 +1083,6 @@ def test_export_gpt_generation( output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, ).to(device="cuda") # "Context phase": use full input sequence length @@ -1152,3 +1143,64 @@ def test_export_ctx_manager(enabled): with te.onnx_export(enabled): assert is_in_onnx_export_mode() == enabled assert is_in_onnx_export_mode() == False + + +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +def test_trt_integration(fp8_recipe: recipe.Recipe): + + model = te.TransformerLayer( + hidden_size=128, + ffn_hidden_size=128, + num_attention_heads=4, + ).eval() + + if type(fp8_recipe) == recipe.Float8CurrentScaling: + # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling + model = te.LayerNormMLP(128, 128) + + inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) + + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + out_ref = model(*inps) + + onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx") + os.close(onnx_fd) + try: + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.onnx_export(enabled=True): + torch.onnx.export( + model, + inps, + onnx_path, + output_names=["output"], + dynamo=True, + custom_translation_table=te_translation_table, + ) + + os.system(f"trtexec --onnx={onnx_path} --saveEngine={onnx_path}.engine") + + # Run TRT engine + logger = trt.Logger(trt.Logger.WARNING) + runtime = trt.Runtime(logger) + with open(onnx_path + ".engine", "rb") as f: + engine_data = f.read() + engine = runtime.deserialize_cuda_engine(engine_data) + context = engine.create_execution_context() + context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr()) + stream = torch.cuda.Stream() + + out = torch.zeros_like(out_ref) + context.set_tensor_address("output", out.data_ptr()) + + context.execute_async_v3(stream_handle=stream.cuda_stream) + stream.synchronize() + + # Compare TRT and TE outputs + atol = 5e-2 if fp8_recipe is not None else 1e-4 + rtol = 5e-2 if fp8_recipe is not None else 1e-4 + torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) + finally: + try: + os.remove(onnx_path) + except FileNotFoundError: + pass diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index dd6c6a3b0..fa56852ff 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -3,10 +3,11 @@ # See LICENSE for license information. 
 import random
-import pytest
 import torch

 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
+from utils import dtype_tols
+


 class TestParallelCrossEntropy:
@@ -19,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float):
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )

-    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
-
+    def generate_input(
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        ignore_idx: bool,
+        device: torch.device = "cuda",
+    ):
         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
         ignore = random.sample(range(0, SQ - 1), 5)

+        # Generate random data
         if swap_dim:
-            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
+            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
         else:
-            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
+            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)

         if ignore_idx:
             for i in ignore:
@@ -41,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
             else:
                 self.tar_test[0][i] = -100

+        # Make copy of data for reference implementation
         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

+        # Enable autograd
+        self.input_test.requires_grad_()
+        self.input_ref.requires_grad_()
+
     def one_iteration_test(
         self,
         dtype: torch.dtype,
@@ -53,18 +65,20 @@
         ignore_idx: bool = False,
     ):

+        # Random data
         self.generate_input(dtype, swap_dim, ignore_idx)

-        self.input_test.requires_grad_(True)
-        self.input_ref.requires_grad_(True)
-
+        # Forward pass
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )
-
         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)

-        # Handle backward pass based on the test scenario
+        # Compute square to avoid trivial backward pass
+        test_loss = torch.square(test_loss)
+        ref_loss = torch.square(ref_loss)
+
+        # Backward pass
         if reduce_loss:
             test_loss.backward()
             ref_loss.backward()
@@ -72,16 +86,18 @@
             test_loss.sum().backward()
             ref_loss.sum().backward()

-        test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
-
-        if ignore_idx:
-            print(test_loss, ref_loss)
-
-        # Compare gradients when backward pass was called
-        torch.testing.assert_close(
-            torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-        )
-
+        # Check that loss and grad input match
+        tols = dtype_tols(dtype)
+        test_loss = test_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.reshape(test_loss.size())
+        test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
+        torch.testing.assert_close(test_loss, ref_loss, **tols)
+        torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)
+
+        # Reset data
         self.input_test = None
         self.input_ref = None
         self.tar_test = None
diff
--git a/tests/pytorch/test_qk_norm.py b/tests/pytorch/test_qk_norm.py index 6f4e62f81..d45ec283c 100644 --- a/tests/pytorch/test_qk_norm.py +++ b/tests/pytorch/test_qk_norm.py @@ -8,10 +8,10 @@ import torch -@pytest.mark.parametrize("use_qk_norm", [False, True]) +@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"]) @pytest.mark.parametrize("attention_type", ["self", "cross"]) @pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5]) -def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None: +def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None: """Test QK normalization functionality, module structure, and numerical behavior.""" hidden_size = 256 num_attention_heads = 8 @@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_type=attention_type, - use_qk_norm=use_qk_norm, + qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, bias=False, device="cuda", ).cuda() - # Check module structure based on use_qk_norm parameter - if use_qk_norm: - assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True" - assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module" - assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module" - # Check that the module is L2Norm type - from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization - - assert isinstance( - mha.qk_norm, L2Normalization - ), "qk_norm should be an L2Normalization module" + # Check module structure based on qk_norm_type parameter + if qk_norm_type is not None: + assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None" + assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None" + + # Check that the modules are of the correct type + if qk_norm_type == "L2Normalization": + from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization + + assert isinstance( + mha.q_norm, L2Normalization + ), "q_norm should be an L2Normalization module" + assert isinstance( + mha.k_norm, L2Normalization + ), "k_norm should be an L2Normalization module" + # For L2 normalization, q_norm and k_norm should be the same instance (parameter-free) + assert ( + mha.q_norm is mha.k_norm + ), "q_norm and k_norm should be the same instance for L2 normalization" + + elif qk_norm_type == "RMSNorm": + from transformer_engine.pytorch.module.rmsnorm import RMSNorm + + assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module" + assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module" + # For RMS normalization, q_norm and k_norm should be separate instances + assert ( + mha.q_norm is not mha.k_norm + ), "q_norm and k_norm should be separate instances for RMS normalization" + + elif qk_norm_type == "LayerNorm": + from transformer_engine.pytorch.module.layernorm import LayerNorm + + assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module" + assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module" + # For LayerNorm, q_norm and k_norm should be separate instances + assert ( + mha.q_norm is not mha.k_norm + ), "q_norm and k_norm should be separate instances for LayerNorm" + + else: + # For extensibility - just ensure they exist + assert mha.q_norm is not None, f"q_norm should exist for qk_norm_type={qk_norm_type}" + assert mha.k_norm 
is not None, f"k_norm should exist for qk_norm_type={qk_norm_type}" else: - assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False" + assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None" + assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None" # Create input tensors batch_size = 2 # Use a fixed batch size for testing @@ -89,17 +123,14 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf" -def test_qk_norm_output_difference() -> None: +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_output_difference(qk_norm_type) -> None: """Test that QK normalization actually changes the output compared to no normalization.""" hidden_size = 256 num_attention_heads = 8 seq_len = 128 batch_size = 2 - # Use same random seed to ensure identical weight initialization - current_rng_state = torch.get_rng_state() - current_cuda_rng_state = torch.cuda.get_rng_state() - # Reset to a known seed for reproducible initialization torch.manual_seed(42) torch.cuda.manual_seed(42) @@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None: mha_with_norm = MultiheadAttention( hidden_size=hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=True, + qk_norm_type=qk_norm_type, bias=False, device="cuda", ).cuda() @@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None: mha_no_norm = MultiheadAttention( hidden_size=hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=False, + qk_norm_type=None, bias=False, device="cuda", ).cuda() @@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None: # Outputs should be different when QK normalization is enabled assert not torch.allclose( output_with_norm, output_no_norm, atol=1e-6 - ), "QK normalization should change the output, but outputs are identical" + ), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical" -def test_qk_norm_with_fused_qkv() -> None: +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_with_fused_qkv(qk_norm_type) -> None: """Test QK normalization works with fused QKV parameters.""" hidden_size = 256 num_attention_heads = 8 @@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None: hidden_size=hidden_size, num_attention_heads=num_attention_heads, fuse_qkv_params=True, - use_qk_norm=True, + qk_norm_type=qk_norm_type, bias=False, device="cuda", ).cuda() @@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None: ), f"Output shape mismatch: {output.shape}" -def test_qk_norm_transformer_layer_output_difference() -> None: +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None: """Test that QK normalization actually changes TransformerLayer output compared to no normalization.""" from transformer_engine.pytorch import TransformerLayer @@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None: seq_len = 128 batch_size = 2 - # Use same random seed to ensure identical weight initialization - current_rng_state = torch.get_rng_state() - current_cuda_rng_state = torch.cuda.get_rng_state() - # Reset to a known seed for reproducible initialization torch.manual_seed(42) torch.cuda.manual_seed(42) @@ -196,7 +225,7 @@ def 
test_qk_norm_transformer_layer_output_difference() -> None: hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=True, + qk_norm_type=qk_norm_type, bias=False, device="cuda", ).cuda() @@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None: hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=False, + qk_norm_type=None, bias=False, device="cuda", ).cuda() @@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None: output_no_norm = transformer_no_norm(hidden_states) # Outputs should be different when QK normalization is enabled - assert not torch.allclose( - output_with_norm, output_no_norm, atol=1e-6 - ), "QK normalization should change the TransformerLayer output, but outputs are identical" + assert not torch.allclose(output_with_norm, output_no_norm, atol=1e-6), ( + f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs" + " are identical" + ) # Check that outputs have expected shapes and properties assert output_with_norm.shape == ( @@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None: assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf" assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN" assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf" + + +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_before_after_rope(qk_norm_type) -> None: + """Test that QK normalization before and after RoPE works without errors.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 64 + batch_size = 2 + + # Create model with QK norm after RoPE (default) + mha_after = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + qk_norm_type=qk_norm_type, + qk_norm_before_rope=False, + bias=False, + device="cuda", + ).cuda() + + # Create model with QK norm before RoPE + mha_before = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + qk_norm_type=qk_norm_type, + qk_norm_before_rope=True, + bias=False, + device="cuda", + ).cuda() + + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Create RoPE embeddings + head_dim = hidden_size // num_attention_heads + rotary_dim = head_dim // 2 + rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32) + + with torch.no_grad(): + output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb) + output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb) + + output_after_no_rope = mha_after(hidden_states) + output_before_no_rope = mha_before(hidden_states) + + # Check output shapes and properties + expected_shape = (seq_len, batch_size, hidden_size) + for output in [ + output_after_rope, + output_before_rope, + output_after_no_rope, + output_before_no_rope, + ]: + assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}" + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + + assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape" + assert mha_after.qk_norm_before_rope == False, "mha_after should have qk_norm_before_rope=False" + assert mha_before.qk_norm_before_rope 
== True, "mha_before should have qk_norm_before_rope=True" + + +def test_different_qk_norm_types_produce_different_outputs() -> None: + """Test that different QK normalization types produce different outputs.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 128 + batch_size = 2 + + # Use same random seed to ensure identical weight initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create model with L2 normalization + mha_l2 = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + qk_norm_type="L2Normalization", + bias=False, + device="cuda", + ).cuda() + + # Reset to same seed for identical initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create model with RMS normalization + mha_rms = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + qk_norm_type="RMSNorm", + bias=False, + device="cuda", + ).cuda() + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Compare outputs with identical weights but different QK norm types + with torch.no_grad(): + output_l2 = mha_l2(hidden_states) + output_rms = mha_rms(hidden_states) + + # Outputs should be different when using different normalization types + assert not torch.allclose( + output_l2, output_rms, atol=1e-6 + ), "L2 and RMS normalization should produce different outputs, but outputs are identical" + + # Check that outputs have expected shapes and properties + assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape" + assert not torch.isnan(output_l2).any(), "L2 output contains NaN" + assert not torch.isinf(output_l2).any(), "L2 output contains Inf" + assert not torch.isnan(output_rms).any(), "RMS output contains NaN" + assert not torch.isinf(output_rms).any(), "RMS output contains Inf" diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 5aa91de52..c59bf376a 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -194,12 +194,6 @@ def test_fp8_scale_update_with_linear_fuser_op( amax_compute_algo=amax_compute_algo, ) - # Get FP8 meta tensors - with te.fp8_autocast(fp8_recipe=recipe): - x_fp8_meta = op.get_quantizer("forward", 0) - w_fp8_meta = op.get_quantizer("forward", 1) - dy_fp8_meta = op.get_quantizer("backward", 0) - # Perform training steps x_history = [] w_history = [] @@ -231,19 +225,30 @@ def test_fp8_scale_update_with_linear_fuser_op( y = op(x) y.backward(dy) - def check_amax_history( - fp8_meta: dict, - ref_amax_history: Iterable[float], - ) -> None: - """Check that amax history matches expected values""" - if len(ref_amax_history) > amax_history_len: - ref_amax_history = ref_amax_history[-amax_history_len:] + def check_metas( + test_scale: float, + test_amax_history: torch.Tensor, + ref_amax_history_list: list[float], + stage: str, + ): + """Check that meta tensors match expected values""" + + # Compute amax + if len(ref_amax_history_list) > amax_history_len: + ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :] ref_amax_history = torch.tensor( - ref_amax_history, + ref_amax_history_list, dtype=torch.float32, device=device, ) - test_amax_history = fp8_meta.amax_history[:, 0] + if amax_compute_algo == "max": + ref_amax = max(ref_amax_history_list) + elif amax_compute_algo == "most_recent": + ref_amax = ref_amax_history_list[-1] + else: + raise RuntimeError(f"{amax_compute_algo=} is not supported") + + # Compare amax 
history tols = dict(rtol=0, atol=0) torch.testing.assert_close( test_amax_history[-(step + 1) :], @@ -251,23 +256,6 @@ def check_amax_history( **tols, ) - def check_scale( - quantizer: Float8Quantizer, - ref_amax_history: Iterable[float], - stage: str, - ): - """Check that scale and scale reciprocal match expected values""" - - # Compute amax - if len(ref_amax_history) > amax_history_len: - ref_amax_history = ref_amax_history[-(amax_history_len + 1) :] - if amax_compute_algo == "max": - ref_amax = max(ref_amax_history) - elif amax_compute_algo == "most_recent": - ref_amax = ref_amax_history[-1] - else: - raise RuntimeError(f"{amax_compute_algo=} is not supported") - # Compute scale max_val = { "forward": 448.0 if not is_fp8_fnuz() else 240.0, @@ -275,16 +263,26 @@ def check_scale( }[stage] ref_scale = (max_val / ref_amax) / (2**margin) - # Check values in FP8 meta tensors + # Compare scale torch.testing.assert_close( - quantizer.scale.item(), + test_scale, ref_scale, ) + # Get scaling factors + x_test_scale = op.get_quantizer("forward", 0).scale.item() + w_test_scale = op.get_quantizer("forward", 1).scale.item() + dy_test_scale = op.get_quantizer("backward", 0).scale.item() + + # Get amax histories + x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0] + w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1] + dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0] + # Check that results match expected values - check_scale(x_fp8_meta, x_history, "forward") - check_scale(w_fp8_meta, w_history, "forward") - check_scale(dy_fp8_meta, dy_history, "backward") + check_metas(x_test_scale, x_test_history, x_history, "forward") + check_metas(w_test_scale, w_test_history, w_history, "forward") + check_metas(dy_test_scale, dy_test_history, dy_history, "backward") @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"]) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 9fbadd4b9..a7d762c3d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1,13 +1,10 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -import os -from dataclasses import dataclass from typing import Optional -from contextlib import nullcontext import torch import pytest @@ -21,11 +18,9 @@ fp8_model_init, ) from transformer_engine.pytorch.utils import ( - get_device_compute_capability, init_method_normal, scaled_init_method_normal, is_bf16_compatible, - get_cudnn_version, ) from transformer_engine.pytorch import ( LayerNormLinear, @@ -35,7 +30,6 @@ TransformerLayer, RMSNorm, LayerNorm, - get_cpu_offload_context, ) from transformer_engine.common import recipe import transformer_engine_torch as tex @@ -50,21 +44,17 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint -from utils import dtype_tols +from utils import ModelConfig # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() -) +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Record initial RNG state from script run. seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) @@ -82,82 +72,33 @@ ) -def create_meta(scale_factor: float, size: int = 1): - meta = tex.FP8TensorMeta() - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor - return meta - - -def custom_amax_to_scale( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: torch.Tensor, - recipe: recipe.DelayedScaling, -) -> torch.Tensor: - """Custom func to test recipe.""" - sf = fp8_max / amax - sf = torch.where(amax > 0.0, sf, scale) - sf = torch.where(torch.isfinite(amax), sf, scale) - - return sf - - -def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: - """Custom func to test recipe.""" - return torch.min(amax_history, dim=0).values - - -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - global _cpu_rng_state, _cuda_rng_state - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - -@dataclass -class ModelConfig: - """Transformer model configuration""" - - num_layers: int - seq_len: int - batch_size: int - hidden_size: int - num_attention_heads: int - kv_channels: Optional[int] = None - - def is_fp8_supported(self): - if self.seq_len * self.batch_size % 16: - return False - if self.hidden_size % 16: - return False - return True +def is_fp8_supported(config: ModelConfig): + if ( + config.max_seqlen_q * config.batch_size % 16 + or config.max_seqlen_kv * config.batch_size % 16 + ): + return False + if config.hidden_size % 16 or config.hidden_size_kv % 16: + return False + return True model_configs = { - "126m": ModelConfig(12, 2048, 2, 768, 12), - "small": ModelConfig(2, 32, 2, 64, 2), - "weird": ModelConfig(2, 37, 3, 69, 3), - "large": ModelConfig(1, 128, 2, 512, 4, 128), + "126m": ModelConfig(2, 2048, 12, 64, num_layers=12), + "small": ModelConfig(2, 32, 2, 32, num_layers=2), + "weird": ModelConfig(3, 37, 3, 23, num_layers=2), + "large": ModelConfig(2, 128, 4, 128, num_layers=1), } -fp8_recipes = [ - None, # Test non-FP8 - recipe.MXFP8BlockScaling(), # Test default - recipe.Float8CurrentScaling(), # Test default - recipe.Float8BlockScaling(), # Test default - recipe.DelayedScaling(), # Test default - recipe.DelayedScaling( # Test most_recent algo - amax_history_len=16, - amax_compute_algo="most_recent", - ), - recipe.DelayedScaling( # Test custom amax and scale compute algo - fp8_format=recipe.Format.E4M3, - amax_compute_algo=custom_amax_compute, - scaling_factor_compute_algo=custom_amax_to_scale, - ), -] +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_block_scaling_available: + fp8_recipes.append(recipe.Float8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) 
+fp8_recipes.append(None) param_types = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher @@ -166,7 +107,18 @@ def is_fp8_supported(self): all_boolean = [True, False] batch_sizes_with_zero = [0, 1, 2] -all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"] +all_activations = [ + "gelu", + "geglu", + "qgelu", + "qgeglu", + "relu", + "reglu", + "srelu", + "sreglu", + "silu", + "swiglu", +] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -181,67 +133,9 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() -def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): - # Initialize loss function and optimizer. - loss_fn = torch.nn.MSELoss() - optimizer = torch.optim.SGD(block.parameters(), lr=0.1) - - # Placeholders used for capture. - static_input = torch.randn( - config.seq_len, - config.batch_size, - config.hidden_size, - device="cuda", - dtype=dtype, - requires_grad=True, - ) - static_target = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype - ) - - real_input = torch.rand_like(static_input) - real_target = torch.rand_like(static_target) - - - use_fp8 = fp8_recipe is not None - if skip_wgrad: - _disable_wgrads(block) - - # Pre graph capture warmup in a separate stream. - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(3): - optimizer.zero_grad(set_to_none=True) - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): - out = block(static_input) - loss = loss_fn(out, static_target) - loss.backward() - optimizer.step() - torch.cuda.current_stream().wait_stream(s) - - # Capture. - g = torch.cuda.CUDAGraph() - optimizer.zero_grad(set_to_none=True) - with torch.cuda.graph(g): - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): - static_output = block(static_input) - static_loss = loss_fn(static_output, static_target) - static_loss.backward() - optimizer.step() - - # Fills the graph's input memory with new data to compute on - with torch.no_grad(): - static_input.copy_(real_input) - static_target.copy_(real_target) - g.replay() - - torch.cuda.synchronize() - - def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=torch.float32, device="cuda", requires_grad=True, @@ -249,7 +143,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states.retain_grad() te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -276,14 +170,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -314,9 +208,9 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci assert len(failed_grads) == 0, f"Gradient not accumulated 
for {failed_grads}." -def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): +def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -325,16 +219,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): if skip_wgrad: _disable_wgrads(block) - if cpu_offload: - offload_context, sync_function = get_cpu_offload_context(enabled=True) - else: - offload_context = nullcontext() - sync_function = lambda x: x - use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context: + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): te_out = block(te_inp_hidden_states) - te_out = sync_function(te_out) loss = te_out.sum() loss.backward() torch.cuda.synchronize() @@ -342,7 +229,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -350,7 +237,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_q), dtype=torch.bool, device="cuda", ) @@ -368,21 +255,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) enc_dec_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -410,7 +297,7 @@ def _test_sanity_common( pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=not skip_dgrad, @@ -438,7 +325,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), device="cuda", requires_grad=True, ) @@ -493,13 +380,7 @@ def test_sanity_layernorm_linear( config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -527,13 
+408,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -560,16 +435,10 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip("Quantized model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size - num_tokens = bs * config.seq_len + num_tokens = bs * config.max_seqlen_q if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -605,16 +474,10 @@ def test_sanity_grouped_linear( ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. bs = bs * 16 - num_tokens = bs * config.seq_len * (num_gemms - 1) + num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -626,7 +489,7 @@ def test_sanity_grouped_linear( inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() - m_splits = [bs * config.seq_len] * num_gemms + m_splits = [bs * config.max_seqlen_q] * num_gemms if empty_split == "first": m_splits[0] = 0 elif empty_split == "last": @@ -664,13 +527,7 @@ def test_sanity_layernorm_mlp( config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -695,39 +552,24 @@ def test_sanity_layernorm_mlp( @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("bias", all_boolean) -@pytest.mark.parametrize("activation", all_activations) +@pytest.mark.parametrize("activation", ["gelu", "swiglu"]) @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) -@pytest.mark.parametrize("cpu_offload", all_boolean) def test_sanity_gpt( dtype, 
fp8_recipe, model, skip_wgrad, - zero_centered_gamma, bias, activation, normalization, parallel_attention_mlp, - cpu_offload, ): - if IS_HIP_EXTENSION and cpu_offload: - pytest.skip("cpu_offloading not supported in rocm TE") - - if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("CPU offload is not supported in debug mode.") config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -737,7 +579,7 @@ def test_sanity_gpt( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -746,7 +588,6 @@ def test_sanity_gpt( params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, bias=bias, activation=activation, normalization=normalization, @@ -754,7 +595,7 @@ def test_sanity_gpt( parallel_attention_mlp=parallel_attention_mlp, ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) + _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) def test_sanity_gpt_126m(): @@ -771,12 +612,10 @@ def test_sanity_gpt_126m(): fp8_recipe=fp8_recipe, model="126m", skip_wgrad=False, - zero_centered_gamma=True, bias=True, activation="gelu", normalization="LayerNorm", parallel_attention_mlp=False, - cpu_offload=False, ) @@ -784,19 +623,14 @@ def test_sanity_gpt_126m(): @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): +def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization): config = model_configs[model] if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -806,7 +640,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -815,7 +649,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, params_dtype=dtype, apply_residual_connection_post_layernorm=True, output_layernorm=True, - zero_centered_gamma=zero_centered_gamma, self_attn_mask_type="causal", normalization=normalization, device="cuda", @@ -836,7 +669,6 @@ def test_sanity_bert_126m(): fp8_recipe=fp8_recipe, model="126m", skip_wgrad=False, - zero_centered_gamma=False, normalization="LayerNorm", ) @@ -845,19 +677,14 @@ def 
test_sanity_bert_126m(): @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): +def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization): config = model_configs[model] if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -867,7 +694,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -877,7 +704,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no apply_residual_connection_post_layernorm=False, output_layernorm=False, layer_type="decoder", - zero_centered_gamma=zero_centered_gamma, normalization=normalization, device="cuda", ) @@ -897,7 +723,6 @@ def test_sanity_T5_126m(): fp8_recipe=fp8_recipe, model="126m", skip_wgrad=False, - zero_centered_gamma=False, normalization="LayerNorm", ) @@ -910,13 +735,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -926,7 +745,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -942,18 +761,11 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("skip_wgrad", all_boolean) -def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): +def test_sanity_drop_path(dtype, fp8_recipe, model): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -963,7 +775,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - 
config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -976,7 +788,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): device="cuda", ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) + _test_sanity_e2e(block, dtype, config, fp8_recipe, False) @pytest.mark.parametrize("dtype", param_types) @@ -987,13 +799,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1003,7 +809,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1016,27 +822,18 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): device="cuda", ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) + _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) -def test_sanity_gradient_accumulation_fusion( - dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma -): +def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1046,7 +843,7 @@ def test_sanity_gradient_accumulation_fusion( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1055,7 +852,6 @@ def test_sanity_gradient_accumulation_fusion( params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, fuse_qkv_params=True, fuse_wgrad_accumulation=True, device="cuda", @@ -1064,53 +860,6 @@ def test_sanity_gradient_accumulation_fusion( _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad) -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, 
zero_centered_gamma, normalization): - - config = model_configs[model] - - if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_recipe.float8_block_scaling(): - pytest.skip("cuda graph not supported for float8_block_scaling recipe") - if not config.is_fp8_supported(): - pytest.skip("Model config does not support FP8") - - sigma = 0.023 - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - - block = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - params_dtype=dtype, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - fuse_qkv_params=True, - normalization=normalization, - device="cuda", - ) - - _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad) - - def test_model_multiple_cast(): a = torch.zeros((16, 16), device="cuda") m = Linear(16, 32) @@ -1165,134 +914,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): torch.cuda.synchronize() -#TODO: rocm fused_attn backends does not support fp8 yet -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") -@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") -@pytest.mark.parametrize("model", ["large"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): - config = model_configs[model] - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = _run_attention_extra_state( - dtype, config, mimic_v1_6=True, checkpoint=True - ) - - # Check that results match - tols = dtype_tols(dtype) - if dtype in (torch.float16, torch.bfloat16): - tols.update(dict(rtol=2e-2, atol=2e-3)) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - - -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): - steps = 10 - path = "checkpoint.pt" - fp8_enabled = True - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_enabled, - fp8_mha=False, - ) - - reset_rng_states() - hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - - def get_model(dtype, config): - sigma = 0.023 - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - - with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): - block = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - 
output_layer_init_method=output_layer_init_method, - hidden_dropout=0.0, - attention_dropout=0.0, - fuse_qkv_params=True, - params_dtype=dtype, - device="cuda", - ) - return block - - block = get_model(dtype, config) - for i in range(steps // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - if checkpoint: - sd = block.state_dict() - if mimic_v1_6: - sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ - "self_attention.core_attention._extra_state" - ] - del sd["self_attention.core_attention._extra_state"] - torch.save(sd, path) - - param_grads = [] - for p in block.parameters(): - if p.requires_grad: - param_grads.append(p.grad.clone()) - - _cpu_rng_state_new = torch.get_rng_state() - _cuda_rng_state_new = torch.cuda.get_rng_state() - - del block - block = get_model(dtype, config) - block.load_state_dict(torch.load(path, weights_only=False)) - torch.set_rng_state(_cpu_rng_state_new) - torch.cuda.set_rng_state(_cuda_rng_state_new) - - for p in block.parameters(): - if p.requires_grad: - p.grad = param_grads.pop(0) - - assert not param_grads, "Oops!" - - for i in range((steps + 1) // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - torch.cuda.synchronize() - - if os.path.exists(path): - os.remove(path) - - outputs = [output, hidden_states.grad] - for p in block.parameters(): - if p.requires_grad: - outputs.append(p.grad) - - return outputs - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): """Test the functionality of replace_raw_data""" @@ -1388,6 +1009,32 @@ def backward(ctx, grad_output): torch.testing.assert_close(grad_checkpoint, grad_standard) +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_linear_frozen_weights_memory_default_recipe(): + """Test that memory usage is optimized when weights are frozen for MXFP8.""" + dim = 1024 + linear = Linear(dim, dim, bias=False) + x = torch.randn(dim, dim, requires_grad=True, device="cuda") + + # Freeze weights + linear.weight.requires_grad = False + + # Forward and backward pass with FP8 + with fp8_autocast(): + o = linear(x) + g_o = torch.randn_like(o) + + max_memory_before_backward = torch.cuda.max_memory_allocated() + o.backward(g_o) + max_memory_after_backward = torch.cuda.max_memory_allocated() + + memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6 + assert memory_diff < 5.5, ( + f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the" + " grad_output should be quantized only columnwise." + ) + + @pytest.mark.parametrize( "module_name", ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"), diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 0c50592bd..684c15737 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -1,18 +1,31 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
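The hunk below consolidates the shared PyTorch test helpers into `tests/pytorch/utils.py`. A minimal, illustrative sketch of driving the new `ModelConfig` and `get_available_attention_backends` helpers (not part of the patch; the `bshd_bshd_bshd` layout string and the `tests.pytorch.utils` import path are assumptions, and a working Transformer Engine install with a GPU is required):

```python
import torch

# Assumes the TE repo root is on sys.path so that "tests" is importable.
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

# batch 2, seqlen 2048, 16 heads, head_dim 128, causal mask
config = ModelConfig(2, 2048, 16, 128, attn_mask_type="causal")

available, flash_backend, fused_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",  # assumed layout string for separate Q/K/V tensors
)
print(available, flash_backend, fused_backends)
```

The third element of the returned tuple lists every fused-attention backend that accepted the configuration; on ROCm the helper probes the AOTriton and CK backends via the `NVTE_FUSED_ATTN_AOTRITON`/`NVTE_FUSED_ATTN_CK` environment variables, while on CUDA it cycles `NVTE_FUSED_ATTN_BACKEND` through backends 0-2 (`F16_max512_seqlen`, `F16_arbitrary_seqlen`, `FP8`).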
from __future__ import annotations +import logging +import os +from contextlib import contextmanager + +import pytest import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type import transformer_engine_torch as tex +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + get_attention_backend, + AttentionParams, + AttentionLogging, +) +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend torch_float8_e4m3_type = get_torch_float8_e4m3_type() torch_float8_e5m2_type = get_torch_float8_e5m2_type() @@ -111,3 +124,214 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() raise ValueError(f"Unsupported quantization scheme ({name})") + + +# Cached RNG state +_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + +def reset_rng_states() -> None: + """Revert to deterministic RNG state""" + global _rng_states + if _rng_states is None: + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state()) + else: + cpu_rng_state, cuda_rng_state = _rng_states + torch.set_rng_state(cpu_rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + + +class ModelConfig: + def __init__( + self, + batch_size: int, + max_seqlen_q: int, + num_heads: int, + head_dim_qk: int, + max_seqlen_kv: int = None, + num_gqa_groups: int = None, + head_dim_v: int = None, + dropout_p: float = 0.0, + attn_mask_type: str = "no_mask", + attn_bias_type: str = "no_bias", + alibi_type: str = "none", + bias_shape: str = "1hss", + window_size: Tuple[int, int] = (-1, -1), + total_requests: int = None, + max_ctx_len: int = None, + num_layers: int = 1, + eps: float = 1e-5, + ): + self.batch_size = batch_size + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv + self.num_heads = num_heads + self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups + self.head_dim_qk = head_dim_qk + self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v + if self.head_dim_qk == self.head_dim_v: + self.kv_channels = self.head_dim_qk + else: + self.kv_channels = (self.head_dim_qk, self.head_dim_v) + self.hidden_size = self.num_heads * self.head_dim_qk + self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v + self.dropout_p = dropout_p + self.attn_mask_type = attn_mask_type + self.attn_bias_type = attn_bias_type + self.alibi_type = alibi_type + self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" + self.bias_shape = bias_shape + self.window_size = window_size + self.total_requests = total_requests + self.max_ctx_len = max_ctx_len + self.num_layers = num_layers + self.eps = eps + + +@contextmanager +def logging_context(highest_level=logging.WARNING): + previous_level = logging.root.manager.disable + logging.disable(highest_level) + try: + yield + finally: + logging.disable(previous_level) + + +if IS_HIP_EXTENSION: + class EnvVarCleaner: + def __init__(self, envs_): + self.envs = envs_ + self.flags = {} + for env in self.envs: + if env in os.environ: + self.flags[env] = os.environ[env] + + def __del__(self): + for 
env in self.envs: + if env in self.flags: + os.environ[env] = self.flags[env] + else: + os.environ.pop(env, None) + + +def get_available_attention_backends( + config: ModelConfig, + qkv_dtype: torch.dtype, + qkv_layout: str, + window_size: Tuple[int, int] = (-1, -1), + pad_between_seqs: bool = False, + context_parallel: bool = False, + deterministic: bool = False, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + is_training: bool = True, + inference_params: Optional[InferenceParams] = None, +) -> Tuple[List, List]: + """Check for all available attention backends that support a model configuration""" + if IS_HIP_EXTENSION: + env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_AOTRITON", "NVTE_FUSED_ATTN_CK"]) + + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True + + alibi_slopes_shape = None + if config.attn_bias_type == "alibi" and config.alibi_type == "custom": + if config.bias_shape == "1hss": + alibi_slopes_shape = [config.num_heads] + if config.bias_shape == "bhss": + alibi_slopes_shape = [config.batch_size, config.num_heads] + + core_attention_bias_shape = ( + config.bias_shape if config.attn_bias_type == "post_scale_bias" else None + ) + core_attention_bias_requires_grad = False + # d=256 is supported by cuDNN 9.0+ for inference but not training + if ( + config.attn_bias_type == "post_scale_bias" + and config.head_dim_qk <= 128 + and config.head_dim_v <= 128 + ): + core_attention_bias_requires_grad = True + + fused_attn_backends = [] + available_backends = None + flash_attention_backend = None + fused_attention_backend = None + + def test(): + attention_params = AttentionParams( + qkv_dtype=qkv_dtype, + qkv_layout=qkv_layout, + batch_size=config.batch_size, + num_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + head_dim_qk=config.head_dim_qk, + head_dim_v=config.head_dim_v, + attn_mask_type=config.attn_mask_type, + window_size=window_size, + alibi_slopes_shape=alibi_slopes_shape, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias_shape=core_attention_bias_shape, + core_attention_bias_requires_grad=core_attention_bias_requires_grad, + pad_between_seqs=pad_between_seqs, + attention_dropout=config.dropout_p, + context_parallel=context_parallel, + deterministic=deterministic, + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + inference_params=inference_params, + ) + ( + use_flash_attention, + flash_attention_backend, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) = get_attention_backend(attention_params) + # Set attention.py _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False + return available_backends, flash_attention_backend, fused_attention_backend + + if IS_HIP_EXTENSION: + backends = {"AOTriton": "AOTRITON", "CK": "CK"} + if AttentionLogging._is_logging_setup is False: + 
AttentionLogging.setup_logging() + with logging_context(highest_level=AttentionLogging._log_level): + for i in backends.keys(): + for k in backends.keys(): + os.environ["NVTE_FUSED_ATTN_"+backends[k]] = "0" + os.environ["NVTE_FUSED_ATTN_"+backends[i]] = "1" + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[i]: + fused_attn_backends.append(fused_attention_backend) + available_backends[1] = len(fused_attn_backends) > 0 + return available_backends, flash_attention_backend, fused_attn_backends + + backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} + if AttentionLogging._is_logging_setup is False: + AttentionLogging.setup_logging() + with logging_context(highest_level=AttentionLogging._log_level): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) + return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e4f61d75a..cefec6d06 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -74,6 +74,12 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") "within the Transformer Engine source.") endif() include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) + +set(CUTLASS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include") +set(CUTLASS_TOOLS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") + else() set(CMAKE_CXX_STANDARD 17) project(transformer_engine LANGUAGES HIP CXX) @@ -131,7 +137,9 @@ list(APPEND transformer_engine_SOURCES transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu + transpose/swap_first_dims.cu activation/gelu.cu + dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu @@ -175,11 +183,18 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp fused_attn/utils.cu + gemm/cutlass_grouped_gemm.cu util/cuda_nvml.cpp comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu comm_gemm_overlap/comm_gemm_overlap.cpp) + +if (NVTE_WITH_CUBLASMP) +list(APPEND transformer_engine_SOURCES + comm_gemm/comm_gemm.cpp) +endif() + add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) else() list(APPEND transformer_engine_SOURCES @@ -222,10 +237,21 @@ else() add_library(transformer_engine SHARED ${te_hip_sources}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}") endif() + target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") - +if (USE_CUDA) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "gemm/cutlass_grouped_gemm.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") +else() + message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") +endif() +endif() #USE_CUDA # Configure dependencies if (USE_CUDA) @@ -233,9 +259,15 @@ 
target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) + target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine SYSTEM PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +target_include_directories(transformer_engine PRIVATE + ${CUTLASS_INCLUDE_DIR} + ${CUTLASS_TOOLS_INCLUDE_DIR}) # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) @@ -253,6 +285,25 @@ if (NVTE_ENABLE_NVSHMEM) target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) endif() +option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) +if (NVTE_WITH_CUBLASMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + find_library(CUBLASMP_LIB + NAMES cublasmp libcublasmp + PATHS ${CUBLASMP_DIR} + PATH_SUFFIXES lib + REQUIRED) + find_library(NVSHMEM_HOST_LIB + NAMES nvshmem_host libnvshmem_host.so.3 + PATHS ${NVSHMEM_DIR} + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") + message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) @@ -332,6 +383,8 @@ make_string_header_from_file(transpose/rtc/cast_transpose.cu string_code_transpose_rtc_cast_transpose_cu) make_string_header_from_file(transpose/rtc/transpose.cu string_code_transpose_rtc_transpose_cu) +make_string_header_from_file(transpose/rtc/swap_first_dims.cu + string_code_transpose_rtc_swap_first_dims_cu) make_string_header_from_file(utils.cuh string_code_utils_cuh) else() @@ -343,6 +396,8 @@ else() string_code_transpose_rtc_cast_transpose_cu) make_string_header_from_file(transpose/rtc/transpose.hip string_code_transpose_rtc_transpose_cu) + make_string_header_from_file(transpose/rtc/swap_first_dims.hip + string_code_transpose_rtc_swap_first_dims_cu) make_string_header_from_file(amd_detail/hip_float8.h string_code_amd_detail_hip_float8_h) endif() @@ -371,6 +426,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) set_source_files_properties(activation/gelu.cu activation/relu.cu activation/swiglu.cu + util/cast.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") endif() diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 8a73138e3..26672bafd 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -223,6 +223,11 @@ def _nvidia_cudart_include_dir() -> str: except ModuleNotFoundError: return "" + # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" + # above doesn't through. However, they don't set "__file__" attribute. 
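+    # (If "nvidia" is only a namespace package, nvidia.__file__ is None and the
+    # Path(nvidia.__file__) call below would raise a TypeError.)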
+ if nvidia.__file__ is None: + return "" + include_dir = Path(nvidia.__file__).parent / "cuda_runtime" return str(include_dir) if include_dir.exists() else "" @@ -251,6 +256,18 @@ def _load_cudnn(): if found: return handle + # Attempt to locate libcudnn via ldconfig + libs = subprocess.check_output( + f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True + ) + libs = libs.decode("utf-8").split("\n") + sos = [] + for lib in libs: + if "libcudnn" in lib and "=>" in lib: + sos.append(lib.split(">")[1].strip()) + if sos: + return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @@ -272,12 +289,12 @@ def _load_nvrtc(): return handle # Attempt to locate NVRTC via ldconfig - libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) + libs = subprocess.check_output( + f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True + ) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: - if "stub" in lib or "libnvrtc-builtins" in lib: - continue if "libnvrtc" in lib and "=>" in lib: sos.append(lib.split(">")[1].strip()) if sos: @@ -286,6 +303,38 @@ def _load_nvrtc(): # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + +@functools.lru_cache(maxsize=None) +def _load_curand(): + """Load cuRAND shared library.""" + # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True) + libs = list(filter(lambda x: not ("stub" in x), libs)) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + # Attempt to locate cuRAND in Python dist-packages + found, handle = _load_nvidia_cuda_library("curand") + if found: + return handle + + # Attempt to locate cuRAND via ldconfig + libs = subprocess.check_output( + f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True + ) + libs = libs.decode("utf-8").split("\n") + sos = [] + for lib in libs: + if "libcurand" in lib and "=>" in lib: + sos.append(lib.split(">")[1].strip()) + if sos: + return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise + return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + te_rocm_build = False @functools.cache @@ -305,6 +354,7 @@ def _load_core_library(): try: _CUDNN_LIB_CTYPES = _load_cudnn() _NVRTC_LIB_CTYPES = _load_nvrtc() + _CURAND_LIB_CTYPES = _load_curand() _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp new file mode 100644 index 000000000..76f46298d --- /dev/null +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -0,0 +1,519 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/comm_gemm.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +using namespace transformer_engine; + +namespace { + +// TODO: log warnings on failures of the *Destroy calls below, once TE has such ability. +// For now, just silently ignoring the errors, since the only diag available in TE is throwing +// exceptions, but these calls will typically be made from destructors, so cannot throw. + +template +auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + NVTE_CHECK_CUDA(create_fn(&raw, std::forward(args)...)); + return std::unique_ptr(raw, destroy_fn); +} + +using CudaStream = + std::unique_ptr, decltype(&cudaStreamDestroy)>; + +CudaStream CudaStreamCreate() { + return CreateWithCudaCheck(cudaStreamCreate, cudaStreamDestroy); +} + +using CudaEvent = std::unique_ptr, decltype(&cudaEventDestroy)>; + +CudaEvent CudaEventCreate(unsigned flags) { + return CreateWithCudaCheck(cudaEventCreateWithFlags, cudaEventDestroy, flags); +} + +template +auto CreateWithCublasMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + if constexpr (raw_last) { + NVTE_CHECK_CUBLASMP(create_fn(std::forward(args)..., &raw)); + } else { + NVTE_CHECK_CUBLASMP(create_fn(&raw, std::forward(args)...)); + } + return std::unique_ptr(raw, destroy_fn); +} + +using CublasMp = + std::unique_ptr, decltype(&cublasMpDestroy)>; + +CublasMp CublasMpCreate(cudaStream_t stream) { + return CreateWithCublasMpCheck(cublasMpCreate, cublasMpDestroy, stream); +} + +using CublasMpGrid = + std::unique_ptr, decltype(&cublasMpGridDestroy)>; + +CublasMpGrid CublasMpGridCreate(int64_t nprow, int64_t npcol, cublasMpGridLayout_t layout, + ncclComm_t comm) { + return CreateWithCublasMpCheck(cublasMpGridCreate, cublasMpGridDestroy, + nprow, npcol, layout, comm); +} + +using CublasMpMatrixDesc = std::unique_ptr, + decltype(&cublasMpMatrixDescriptorDestroy)>; + +CublasMpMatrixDesc CublasMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb, + int64_t rsrc, int64_t csrc, int64_t lld, + cudaDataType_t type, cublasMpGrid_t grid) { + return CreateWithCublasMpCheck( + cublasMpMatrixDescriptorCreate, cublasMpMatrixDescriptorDestroy, m, n, mb, nb, rsrc, csrc, + lld, type, grid); +} + +using CublasMpMatmulDesc = std::unique_ptr, + decltype(&cublasMpMatmulDescriptorDestroy)>; + +CublasMpMatmulDesc CublasMpMatmulDescCreate(cublasComputeType_t compute_type) { + return CreateWithCublasMpCheck( + cublasMpMatmulDescriptorCreate, cublasMpMatmulDescriptorDestroy, compute_type); +} + +} // namespace + +struct NVTECommGemmCtx { + int64_t nranks; + int64_t rank; + ncclComm_t comm; + CudaStream stream; + CudaEvent event; + CublasMp cublas_mp; + CublasMpGrid grid_col_major; + CublasMpGrid grid_row_major; + CublasMpMatrixDesc a_desc; + CublasMpMatrixDesc b_desc; + CublasMpMatrixDesc d_desc; + CublasMpMatmulDesc matmul_desc; + void* workspace; + size_t workspace_size; +}; + +namespace { + +int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) { + // Use non-cyclic layout to maximize opportunity for comm overlap. 
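+  // E.g. global_size = 10 with nranks = 4 yields a block of 3: ranks 0-2 own 3
+  // rows/cols each and rank 3 owns the remaining 1 (a plain ceil-division split).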
+ return (global_size + ctx->nranks - 1) / ctx->nranks; +} + +void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, k, block_size(ctx, m), 0, 0, k, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a0 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, block_size(ctx, m), k, 0, 0, + block_size(ctx, m), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b0 == k, "Unsupported tensor dimensionin B: expected ", k, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), k, 0, 0, + block_size(ctx, n), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b1 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, k, block_size(ctx, n), 0, 0, k, + get_cuda_dtype(b->dtype()), + ctx->grid_row_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d0 == n, "Unsupported tensor dimension in D: expected ", n, ", got ", d0); + *ldd = block_size(ctx, m); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, block_size(ctx, m), block_size(ctx, n), 0, + 0, *ldd, get_cuda_dtype(d->dtype()), + ctx->grid_col_major.get(), ctx->d_desc.get())); +} + +void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a1 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, m, block_size(ctx, k), 0, 0, m, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n), + get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + k, n, block_size(ctx, k), block_size(ctx, n), 0, 0, block_size(ctx, k), + get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: 
expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); +} + +void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_ERROR("N transpose flag is not supported for input A"); + } + if (transb) { + NVTE_ERROR("T transpose flag is not supported for input B"); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, block_size(ctx, k), n, 0, 0, + block_size(ctx, k), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n * ctx->nranks, m, n, 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); + + const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); +} + +using InitMatricesFn = void (*)(NVTECommGemmCtx*, int64_t*, int64_t, int64_t, int64_t, + const Tensor*, const Tensor*, const Tensor*, bool, bool); + +cublasMpMatmulAlgoType_t cublasmp_algo(NVTECommGemmAlgoType algo) { + static const std::unordered_map s_map{ + {kNVTECommGemmAlgoDefault, CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT}, + {kNVTECommGemmAlgoSplitP2P, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P}, + {kNVTECommGemmAlgoSplitMulticast, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_MULTICAST}, + {kNVTECommGemmAlgoAtomicP2P, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_P2P}, + {kNVTECommGemmAlgoAtomicMulticast, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_MULTICAST}, + }; + auto it = s_map.find(algo); + return it != s_map.end() ? it->second : static_cast(algo); +} + +void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECommGemmAlgoType algo, + int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, + const Tensor* d, const Tensor* bias, const Tensor* pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream) { + for (auto t : {a, b, d}) { + NVTE_CHECK(is_tensor_scaling(t->scaling_mode), + "Unsupported scaling mode: " + std::to_string(t->scaling_mode)); + } + + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F)); + + int64_t ldd{}; + init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb); + + const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t trans_b = transb ? 
CUBLAS_OP_T : CUBLAS_OP_N; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, + sizeof trans_a)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, + sizeof trans_b)); + cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, + sizeof algo_attr)); + + const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; + if (is_fp8_dtype(a->dtype())) { + NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, + &a->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(b->dtype())) { + NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, + &b->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(d->dtype())) { + NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, + &d->scale.dptr, sizeof(void*))); + if (d->amax.dptr) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, + &d->amax.dptr, sizeof(void*))); + } + } + + // Might be set to ALLREDUCE before, need to OR with the new flags to set. + cublasMpMatmulEpilogue_t epilogue{}; + size_t size_read{}; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue, &size_read)); + NVTE_CHECK(size_read == sizeof epilogue); + // (bias, gelu, grad) -> epilogue + const std::map, cublasMpMatmulEpilogue_t> flags_to_epilogue{ + {{true, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX_BIAS}, + {{true, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU_BGRAD}, + {{true, false, false}, CUBLASMP_MATMUL_EPILOGUE_BIAS}, + {{true, false, true}, CUBLASMP_MATMUL_EPILOGUE_BGRADB}, + {{false, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX}, + {{false, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU}, + }; + if (auto it = + flags_to_epilogue.find({bias ? bias->data.dptr != nullptr : false, + pre_act_out ? 
pre_act_out->data.dptr != nullptr : false, grad}); + it != flags_to_epilogue.end()) { + epilogue = static_cast(epilogue | it->second); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); + } + + if (bias && bias->data.dptr) { + cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, + sizeof bias_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, + sizeof bias->data.dptr)); + } + + if (pre_act_out && pre_act_out->data.dptr) { + cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, + &aux_type, sizeof aux_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, + &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, + sizeof ldd)); + if (is_fp8_dtype(pre_act_out->dtype())) { + NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, + &scale_mode, sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, + &pre_act_out->scale.dptr, sizeof(void*))); + if (pre_act_out->amax.dptr) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, + &pre_act_out->amax.dptr, sizeof(void*))); + } + } + } + + if (comm_sm_count) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, + &comm_sm_count, sizeof comm_sm_count)); + } + + NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); + + size_t wrksp_size_device{}; + size_t wrksp_size_host{}; + + float alpha = 1.0; + float beta = accumulate ? 1.0 : 0.0; + std::tuple args{ctx->cublas_mp.get(), + ctx->matmul_desc.get(), + m, + n, + k, + &alpha, + a->data.dptr, + 1, + 1, + ctx->a_desc.get(), + b->data.dptr, + 1, + 1, + ctx->b_desc.get(), + &beta, + accumulate ? d->data.dptr : nullptr, + 1, + 1, + accumulate ? 
ctx->d_desc.get() : nullptr, + d->data.dptr, + 1, + 1, + ctx->d_desc.get()}; + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul_bufferSize, + std::tuple_cat(args, std::tuple{&wrksp_size_device, &wrksp_size_host}))); + + std::vector workspace_host(wrksp_size_host); + if (ctx->workspace_size < wrksp_size_device) { + nvshmem_free(ctx->workspace); + ctx->workspace = nvshmem_malloc(wrksp_size_device); + ctx->workspace_size = wrksp_size_device; + } + + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul, + std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size, + workspace_host.data(), workspace_host.size()}))); + + NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0)); +} + +} // namespace + +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_API_CALL(nvte_comm_gemm_ctx_create); + auto stream = CudaStreamCreate(); + auto event = CudaEventCreate(cudaEventDisableTiming); + auto cublas_mp = CublasMpCreate(stream.get()); + + auto col_major = CublasMpGridCreate(nranks, 1, CUBLASMP_GRID_LAYOUT_COL_MAJOR, comm); + auto row_major = CublasMpGridCreate(1, nranks, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, comm); + + // Pre-creating matrix descriptors here, will be initialized with the actual params later. + auto a_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto b_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto d_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + + auto matmul_desc = CublasMpMatmulDescCreate(CUBLAS_COMPUTE_32F); + + return new NVTECommGemmCtx{ + .nranks = nranks, + .rank = rank, + .comm = comm, + .stream = std::move(stream), + .event = std::move(event), + .cublas_mp = std::move(cublas_mp), + .grid_col_major = std::move(col_major), + .grid_row_major = std::move(row_major), + .a_desc = std::move(a_desc), + .b_desc = std::move(b_desc), + .d_desc = std::move(d_desc), + .matmul_desc = std::move(matmul_desc), + }; +} + +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { + NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); + nvshmemx_sync_all_on_stream(ctx->stream.get()); + delete ctx; +} + +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_all_gather_gemm); + cublasmp_gemm(AgGemmInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_reduce_scatter); + cublasmp_gemm(GemmRsInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, 
main_stream); +} + +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_all_reduce); + cublasmp_gemm(GemmArInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) { + NVTE_API_CALL(nvte_comm_gemm_numroc); + return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks); +} diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 40595ea98..ec29e6e12 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -101,10 +101,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl DType::kInt32); } // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0)); /* Defining the launcher order between the communication and GEMM kernels @@ -114,11 +114,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl */ int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); int runtime_version = 0; - cudaRuntimeGetVersion(&runtime_version); + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version)); cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0)); if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { - cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming)); } else { _comm_launch_event = 0; } @@ -129,18 +129,34 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); - if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } + + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } - if (_atomic_gemm) cudaFree(_counter.dptr()); + for (size_t i = 0; i < _stream_compute.size(); i++) { + cudaStreamSynchronize(_stream_compute[i]); + cudaStreamDestroy(_stream_compute[i]); + } - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + auto error = cudaGetLastError(); + if (error != cudaSuccess) { + NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); + } if (_comm_created) { + try { #ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); + destroy_communicator_mpi(_ub_comm); #else - destroy_communicator(_ub_comm); + 
destroy_communicator(_ub_comm); #endif + } catch (const std::exception &e) { + NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } _comm_created = false; } } @@ -282,6 +298,7 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType CommOverlapBase::~CommOverlapBase() { cudaEventDestroy(_start_d2dcopy); + cudaStreamSynchronize(_stream_comm); cudaStreamDestroy(_stream_comm); } @@ -584,6 +601,29 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } // CommOverlapBase::split_overlap_rs +void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + int comm_bytes = _ubuf.bytes(); + int comm_bytes_per_rank = comm_bytes / _tp_size; + + // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush + userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, send_stream); + userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, recv_stream); + + // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf + for (auto stream : {send_stream, recv_stream}) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); + } + + // Next we sync with the main stream + // We have to recapture an event off the comm stream to enable cuda graph capture otherwise the comm stream will be never be joined in the graph + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} + /*************************************************************************************************** * Comm+GEMM Overlap P2P Base (Ring-Exchange) **************************************************************************************************/ @@ -666,7 +706,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } } TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 65da58d5f..1ce89c512 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) { } int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { - if (comm->free_region > NVTE_MAX_REGIONS) return -1; + if (comm->free_region >= NVTE_MAX_REGIONS) return -1; int hndl = comm->free_region; comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); size_t aligned_size = bytes; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 1211392e4..1dcd54d0d 100644 --- 
a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2319,6 +2319,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (comm->push == 0) { kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), reinterpret_cast(flagptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + srcoffset; void *dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset; @@ -2516,8 +2517,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16, comm->ub_timeout); - if (!signalonly) + NVTE_CHECK_CUDA(cudaGetLastError()); + if (!signalonly) { kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + NVTE_CHECK_CUDA(cudaGetLastError()); + } if (comm->use_ce) { NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); } @@ -2532,6 +2536,33 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds reinterpret_cast(0 ? // temporary disable GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) : nullptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} + +void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; + for (int j = 1; j < tp_size; j++) { + int i = (tp_rank + j) % tp_size; + int send_offset = srcoffset + bytes_per_slice * tp_rank; + int recv_offset = dstoffset + bytes_per_slice * tp_rank; + userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); + } +} + +void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; + for (int j = tp_size - 1; j > 0; j--) { + int i = (tp_rank + j) % tp_size; + int send_offset = srcoffset + bytes_per_slice * i; + int recv_offset = dstoffset + bytes_per_slice * i; + userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } @@ -2588,24 +2619,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); producer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); + NVTE_CHECK_CUDA(cudaGetLastError()); } void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) { dim3 block(1); dim3 grid(1); reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -2659,6 
+2694,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in reduce_fp8_in_bf16_out_cuda <<>>(inputs, output, scale, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale, @@ -2714,4 +2750,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud dim3 grid(num_blocks); reduce_bf16_cuda<<>>( inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 03e45b978..4d52fbb64 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -27,7 +27,7 @@ using ExtAllgatherOp = std::function; using ExtBarrierOp = std::function; -#define NVTE_MAX_REGIONS 16 +#define NVTE_MAX_REGIONS 32 #define NVTE_MAX_SMS 32 #define NVTE_MAX_OPS 32 #define NVTE_MAX_PEERS 8192 @@ -304,4 +304,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); +void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); + +void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 483444751..e67694c38 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -28,12 +28,33 @@ __global__ void __launch_bounds__(1) } // namespace +#ifndef __HIP_PLATFORM_AMD__ +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kFloat16: + return CUDA_R_16F; + case DType::kFloat32: + return CUDA_R_32F; + case DType::kBFloat16: + return CUDA_R_16BF; + case DType::kFloat8E4M3: + return CUDA_R_8F_E4M3; + case DType::kFloat8E5M2: + return CUDA_R_8F_E5M2; + default: + NVTE_ERROR("Invalid type"); + } +} +#endif + void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -75,6 +96,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) dim3 grid(numBlocks, 1, 1); \ memset_kernel \ <<>>(ptr, value, size_in_bytes); \ + NVTE_CHECK_CUDA(cudaGetLastError()); \ return; \ } @@ -85,7 +107,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream if (size_in_bytes > 4096) { // Use cudaMemsetAsync for larger sizes. 
- cudaMemsetAsync(ptr, value, size_in_bytes, stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream)); return; } @@ -96,7 +118,11 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream } } // extern "C" +#ifndef __HIP_PLATFORM_AMD__ void checkCuDriverContext(CUstream stream) { + // Ensure the thread's "current" CUDA context is set. + cuda_driver::ensure_context_exists(); + CUcontext ctx; const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx); switch (driver_status) { @@ -117,7 +143,6 @@ void checkCuDriverContext(CUstream stream) { } } -#ifndef __HIP_PLATFORM_AMD__ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { static const std::unordered_map dtypeMapping = []() { std::unordered_map typeMapping = { @@ -165,10 +190,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, void *dataPtr = reinterpret_cast(reinterpret_cast(tensor.dptr) + (offset_elems * type_num_bits) / 8); - NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), + NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT), "Tensor data pointer must be 16B aligned"); - const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; + const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits; NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits, "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 39038724a..ce510334b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -276,6 +276,9 @@ struct QuantizationConfig { }; }; +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t); + +#ifndef __HIP_PLATFORM_AMD__ template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); @@ -287,6 +290,22 @@ constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T "Integral type required."); return DIVUP(static_cast(N), static_cast(M)) * M; } +#else +// DIVUP is called with integral types only for which passing by value is preferred. +// It also allows using of constexpr arguments w/o needing to create storage for references. 
+template +constexpr T DIVUP(T x, T y) { + static_assert(std::is_integral::value, "Integral type required."); + return (((x) + ((y)-1)) / (y)); +} + +template +constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { + static_assert(std::is_integral::value && std::is_integral::value, + "Integral type required."); + return DIVUP(static_cast(N), static_cast(M)) * M; +} +#endif //__HIP_PLATFORM_AMD__ using byte = uint8_t; using int16 = int16_t; @@ -407,9 +426,19 @@ struct BitsNumber { template struct TypeInfo { #if FP4_TYPE_SUPPORTED - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #else - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #endif template @@ -692,8 +721,11 @@ constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +#ifndef __HIP_PLATFORM_AMD__ // Alignment requirements for the Tensor Memory Accelerator (TMA) -constexpr int TMA_gmem_alignment = 16; // global memory address alignment +constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment +#endif inline bool is_aligned_ptr(const void *ptr, size_t alignment) { return reinterpret_cast(ptr) % alignment == 0; @@ -724,9 +756,11 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); #define NVTE_API_CALL(api_name) \ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); +#ifdef __HIP_PLATFORM_AMD__ +#define checkCuDriverContext(stream) {} +#else void checkCuDriverContext(CUstream stream); -#ifndef __HIP_PLATFORM_AMD__ CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. @@ -734,7 +768,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits); -#endif //#ifndef __HIP_PLATFORM_AMD__ +#endif //#ifdef __HIP_PLATFORM_AMD__ bool is_supported_by_CC_100(); diff --git a/transformer_engine/common/dropout/dropout.cu b/transformer_engine/common/dropout/dropout.cu new file mode 100644 index 000000000..c7a5555e2 --- /dev/null +++ b/transformer_engine/common/dropout/dropout.cu @@ -0,0 +1,363 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include + +#include "../common.h" +#include "../utils.cuh" +#include "transformer_engine/dropout.h" + +namespace transformer_engine { +namespace { + +// RNG kernels process chunks of 16 entries +constexpr size_t rng_chunk_size = 16; + +// CUDA block size +constexpr size_t block_size = 128; + +// Vector class to help with vectorized memory accesses +template +union Vector { + using StorageType = typename BytesToType::Type; + StorageType storage; + T entries[kSize]; +}; + +/* Byte-wise less-than comparison + * + * Results are stored in each byte's most-significant bit (MSB). All + * other bits are zero. 
+ */ +__device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) { + // Compare low bits by masking MSBs and subtracting. The resulting + // MSBs are 0 if the low bits of a are less than the low bits of b. + uint32_t result = (a | 0x80808080) - (b & 0x7F7F7F7F); + +#ifndef __HIP_PLATFORM_AMD__ + // Bitwise logical op to get answer in MSBs + // Equivalent logic: result = (a == b) ? !result : b + asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result)); +#else + // AMD GPU: Use bitwise ops to get answer in MSBs + uint32_t mask = (a ^ b); + result = (mask & b) | ~(mask | result); +#endif + + // Mask out everything except MSBs and return + result &= 0x80808080; + return result; +} + +/* Generate dropout mask with 16 bits. + * + * 1 corresponds to keep and 0 to drop. + * + * Consumes 4 values from cuRAND Philox generator. + */ +__device__ __forceinline__ uint16_t make_16bit_mask(uint64_t chunk_idx, uint64_t rng_seed, + uint64_t rng_offset, + uint32_t bytewise_drop_prob) { + // Generate random bits + curandStatePhilox4_32_10_t state; + curand_init(rng_seed, chunk_idx, rng_offset, &state); + const uint4 rand_bits = curand4(&state); + + // Compute mask + // Note: bytewise_less_than fills MSBs (bits 7, 15, 23, 31). By + // shifting 2 bits after every call, every other bit will be filled. + uint32_t result = bytewise_less_than(rand_bits.x, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.y, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.z, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.w, bytewise_drop_prob); + + // Consolidate mask in lowest 16 bits + result |= result >> 17; + + // Flip bits so 0 corresponds to drop + result = ~result; + + return result; +} + +// Dropout forward with FP16/BF16 input and output. +template +__global__ void __launch_bounds__(block_size) + dropout_kernel_fwd_f16(const T *__restrict__ input_ptr, T *__restrict__ output_ptr, + uint8_t *__restrict__ mask_ptr, + const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks, + uint32_t bytewise_drop_prob, float scale) { + static_assert(sizeof(T) == 2); + + // Each thread processes a chunk of 16 entries + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Generate dropout mask + auto local_mask = + make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob); + reinterpret_cast(mask_ptr)[chunk_idx] = local_mask; + + // Read input data + using VectorType = Vector; + VectorType local_data; + local_data = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask +#pragma unroll + for (size_t i = 0; i < rng_chunk_size; i++) { + float val = static_cast(local_data.entries[i]); + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_data.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_data; + } +} + +// Dropout forward with FP8 input and FP16/BF16 output. 
+template +__global__ void __launch_bounds__(block_size) + dropout_kernel_fwd_fp8(const InputType *__restrict__ input_ptr, + const float *__restrict__ input_scale_inv_ptr, + OutputType *__restrict__ output_ptr, uint8_t *__restrict__ mask_ptr, + const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks, + uint32_t bytewise_drop_prob, float scale) { + static_assert(sizeof(InputType) == 1); + static_assert(sizeof(OutputType) == 2); + const float input_scale_inv = *input_scale_inv_ptr; + + // Each thread processes a chunk of 16 entries + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Generate dropout mask + auto local_mask = + make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob); + reinterpret_cast(mask_ptr)[chunk_idx] = local_mask; + + // Read input data + using InputVectorType = Vector; + InputVectorType local_input; + local_input = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask + using OutputVectorType = Vector; + OutputVectorType local_output; +#pragma unroll + for (size_t i = 0; i < rng_chunk_size; i++) { + float val = static_cast(local_input.entries[i]); + val *= input_scale_inv; + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_output.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_output; + } +} + +// Apply dropout mask and scale. +template +__global__ void __launch_bounds__(block_size) + apply_dropout_mask(const T *__restrict__ input_ptr, const uint8_t *__restrict__ mask_ptr, + T *__restrict__ output_ptr, size_t num_chunks, float scale) { + // Each thread processes a chunk of 8 entries. 
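+  // Illustrative note (not part of the original patch): each mask byte gates its chunk LSB-first, so a mask byte of 0b10110101 keeps entries 0, 2, 4, 5, and 7 of the chunk and zeroes entries 1, 3, and 6 before the 1/(1-p) scale is applied.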
+ const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + constexpr size_t chunk_size = 8; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Read dropout mask + uint8_t local_mask = mask_ptr[chunk_idx]; + + // Read input data + using VectorType = Vector; + VectorType local_data; + local_data = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask +#pragma unroll + for (size_t i = 0; i < chunk_size; i++) { + float val = static_cast(local_data.entries[i]); + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_data.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_data; + } +} + +} // namespace + +void dropout_fwd(const Tensor &input, Tensor &output, Tensor &mask, Tensor &rng_state, + float dropout_probability, cudaStream_t stream) { + // Check tensors + const size_t numel = input.numel(); + NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ", + "but scaling mode is ", to_string(input.scaling_mode), "."); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(output.scaling_mode), "."); + NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Mask tensor must be a plain tensor, ", + "but scaling mode is ", to_string(mask.scaling_mode), "."); + NVTE_CHECK(rng_state.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "RNG state tensor must be INT64 tensor with two entries, ", "but scaling mode is ", + to_string(rng_state.scaling_mode), "."); + NVTE_CHECK(output.dtype() == DType::kFloat16 || output.dtype() == DType::kBFloat16, + "Output tensor must be FP16/BF16 tensor, but dtype is ", to_string(output.dtype()), + "."); + NVTE_CHECK(rng_state.dtype() == DType::kInt64, + "RNG state tensor must be INT64 tensor with two entries, but dtype is ", + to_string(rng_state.dtype()), "."); + NVTE_CHECK(numel % 16 == 0, + "Input tensor number of elements must be divisible by 16, but shape is ", + input.shape(), "."); + NVTE_CHECK(numel == output.numel(), "Input tensor (shape=", input.shape(), + ") and output tensor (shape=", output.shape(), ") do not match."); + NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel, + " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), "."); + NVTE_CHECK(rng_state.numel() == 2, "RNG state tensor must be INT64 tensor with two entries, ", + "but shape is ", rng_state.shape(), "."); + NVTE_CHECK(input.data.dptr != nullptr, "Input tensor is missing data."); + NVTE_CHECK(output.data.dptr != nullptr, "Output tensor is missing data."); + NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data."); + NVTE_CHECK(rng_state.data.dptr != nullptr, "RNG state tensor is missing data."); + + // Convert dropout probability to scale and 8-bit representation + NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (", + dropout_probability, ")."); + const float scale = 1 / (1 - dropout_probability); + uint32_t bytewise_drop_prob = static_cast(std::floor(dropout_probability * 256)); + bytewise_drop_prob |= bytewise_drop_prob << 8; + bytewise_drop_prob |= bytewise_drop_prob << 16; + + // CUDA config + const size_t num_chunks = numel / rng_chunk_size; + const size_t num_blocks = DIVUP(num_chunks, 
block_size); + + // Launch kernel depending on input dtype + if (input.dtype() == DType::kFloat16 || input.dtype() == DType::kBFloat16) { + NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()), + ") and output tensor (dtype=", to_string(output.dtype()), ") do not match."); + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.dtype(), DType, + dropout_kernel_fwd_f16<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(rng_state.data.dptr), num_chunks, bytewise_drop_prob, + scale);); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else if (input.dtype() == DType::kFloat8E4M3 || input.dtype() == DType::kFloat8E5M2) { + NVTE_CHECK(input.scale_inv.dptr != nullptr, "Input tensor scale-inverse is not allocated."); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output.dtype(), OutputType, + dropout_kernel_fwd_fp8<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(rng_state.data.dptr), num_chunks, + bytewise_drop_prob, scale); + + );); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else { + NVTE_ERROR("Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ", + "but dtype is ", to_string(input.dtype()), "."); + } +} + +void dropout_bwd(const Tensor &grad_output, const Tensor &mask, Tensor &grad_input, + float dropout_probability, cudaStream_t stream) { + // Check tensors + const size_t numel = grad_output.numel(); + NVTE_CHECK(grad_output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Grad output tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(grad_output.scaling_mode), "."); + NVTE_CHECK(grad_input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Grad input tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(grad_input.scaling_mode), "."); + NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Mask tensor must be a plain tensor, but scaling mode is ", + to_string(mask.scaling_mode), "."); + NVTE_CHECK(grad_output.dtype() == DType::kFloat16 || grad_output.dtype() == DType::kBFloat16, + "Grad output tensor must be FP16/BF16 tensor, but dtype is ", + to_string(grad_output.dtype()), "."); + NVTE_CHECK(grad_output.dtype() == grad_input.dtype(), + "Grad output tensor (dtype=", to_string(grad_output.dtype()), + ") and grad input tensor (dtype=", to_string(grad_input.dtype()), ") do not match."); + NVTE_CHECK(numel % 16 == 0, + "Grad output tensor number of elements must be divisible by 16, but shape is ", + grad_output.shape(), "."); + NVTE_CHECK(numel == grad_input.numel(), "Grad output tensor (shape=", grad_output.shape(), + ") and grad input tensor (shape=", grad_input.shape(), ") do not match."); + NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel, + " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), "."); + NVTE_CHECK(grad_output.data.dptr != nullptr, "Grad output tensor is missing data."); + NVTE_CHECK(grad_input.data.dptr != nullptr, "Grad input tensor is missing data."); + NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data."); + + // Convert dropout probability to scale + NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (", + dropout_probability, ")."); + const float scale = 1 / (1 - dropout_probability); 
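+  // Worked example (illustrative, not part of the original patch): dropout_probability = 0.2 gives scale = 1 / 0.8 = 1.25, so surviving gradient entries are multiplied by 1.25 and the expected gradient magnitude matches the no-dropout case (standard inverted dropout).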
+ + // CUDA config + const size_t num_chunks = numel / 8; + const size_t num_blocks = DIVUP(num_chunks, block_size); + + // Launch kernel + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_output.dtype(), DType, + apply_dropout_mask<<>>( + reinterpret_cast(grad_output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(grad_input.data.dptr), num_chunks, scale);); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine + +void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask, + NVTETensor rng_state, float dropout_probability, cudaStream_t stream) { + NVTE_API_CALL(nvte_dropout_fwd); + using namespace transformer_engine; + dropout_fwd(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(mask), *convertNVTETensorCheck(rng_state), + dropout_probability, stream); +} + +void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input, + float dropout_probability, cudaStream_t stream) { + NVTE_API_CALL(nvte_dropout_bwd); + using namespace transformer_engine; + dropout_bwd(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(mask), + *convertNVTETensorCheck(grad_input), dropout_probability, stream); +} diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index 15708d2d5..5921d97d5 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor thd_read_half_tensor_kernel<<>>( half.data.dptr, tensor.data.dptr, reinterpret_cast(cu_seqlens.data.dptr), batch, hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]); + NVTE_CHECK_CUDA(cudaGetLastError()); } /*************************************************************************************************** @@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step, reinterpret_cast(lse.data.dptr), reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_lse_kernel<<>>( reinterpret_cast(lse.data.dptr), reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso reinterpret_cast(lse.data.dptr), reinterpret_cast(half_lse.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_lse_kernel<<>>( reinterpret_cast(lse.data.dptr), reinterpret_cast(half_lse.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen, lse_per_step_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_out_correction_kernel <<>>( @@ -528,6 +534,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co reinterpret_cast(lse_per_step.data.dptr), 
reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen, lse_per_step_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step, reinterpret_cast(grad.data.dptr), reinterpret_cast(grad_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, hidden_size, total_tokens); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to thd_partition_indices_kernel<<>>( reinterpret_cast(output.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, total_tokens, world_size, rank); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace context_parallel diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 0c261d0fa..59207d59a 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) { prepare_kernel_fwd<<>>( reinterpret_cast(qkvi.data.dptr), reinterpret_cast(qkv.data.dptr), shape[1], shape[2], shape[3], shape[4]);); + NVTE_CHECK_CUDA(cudaGetLastError()); } void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) { @@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), reinterpret_cast(qkv.data.dptr), q_shape[0], q_shape[1], q_shape[2], q_shape[3]);); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace flash_attention diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9d4701730..795697635 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + // 9.10.0: known bugs with SDPA FP8 + (cudnn_runtime_version != 91000)) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -239,19 +241,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10: any head_dim + any arch + fprop + paged - // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91000 && + // 9.10.2: any head_dim + any arch + fprop + paged + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + (!is_training && cudnn_runtime_version >= 91002 && (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && 
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && - // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 && - head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && + // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed + (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && + head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || @@ -358,7 +361,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && // check 64-bit ragged offset support - (supported_ragged_offset_size)) { + (supported_ragged_offset_size) && + // 9.10.0/9.10.1: known bugs with SDPA F16 + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 0932b2cf8..4e6c3c858 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -454,6 +455,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); if (is_ragged_q) { variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_o] = devOffsetsO; @@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -916,6 +919,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); if (is_ragged_q) { variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_o] = devOffsetsO; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3e38a5066..d7f098376 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t 
b, int64_t h, int64_t s_q, int64_t s_kv, in cu_seqlens_to_offsets<<>>( b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + NVTE_CHECK_CUDA(cudaGetLastError()); void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); @@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl( cu_seqlens_to_offsets<<>>( b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + NVTE_CHECK_CUDA(cudaGetLastError()); void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); @@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1( b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1( b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } diff --git a/transformer_engine/common/fused_attn/kv_cache.cu b/transformer_engine/common/fused_attn/kv_cache.cu index 9bdc41e9e..67119c323 100644 --- a/transformer_engine/common/fused_attn/kv_cache.cu +++ b/transformer_engine/common/fused_attn/kv_cache.cu @@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso reinterpret_cast(page_table.data.dptr), reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } dim3 grid_size(b, max_ctx_len); copy_to_kv_cache_kernel<<>>( @@ -166,6 +167,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se reinterpret_cast(tensor.data.dptr), reinterpret_cast(new_tensor.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), b, max_seq_len, h, d); + NVTE_CHECK_CUDA(cudaGetLastError()); } void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b, @@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se reinterpret_cast(tensor.data.dptr), reinterpret_cast(new_tensor.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), b, max_seq_len, h, d); + NVTE_CHECK_CUDA(cudaGetLastError()); } void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index c790b1132..4803b5dc6 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -604,13 +604,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud // workspace size requires 4 bytes uint32_t *dout = 
static_cast(workspace); uint32_t hout{}; - cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream)); constexpr int threads = 128; const int blocks = (len - 1) / threads + 1; get_runtime_num_segments_kernel<<>>(static_cast(cu_seqlen), len, dout); - cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); return hout; } @@ -637,4 +638,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>( rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph); + NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index df9ea6ee5..ccd0bc44c 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos, v_sin; - sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos); + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + float v_cos = shared_mem_cos[d_id]; + float v_sin = shared_mem_sin[d_id]; int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; @@ -49,12 +58,12 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + 
int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos = cosf(freqs[s_id * d2 + d_id]); - float v_sin; - if (!interleaved) { - v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) - : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); - } else { - v_sin = - (d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]); - } + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate; + float v_cos = shared_mem_cos[d_id]; + float v_src_rotate, v_sin; if (!interleaved) { - v_src_rotate = (d_id + d2 / 2 < d2) - ? static_cast(src[offset_src + (d2 / 2) * stride_d]) - : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(src[offset_src + (d2 / 2) * stride_d]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } } else { - v_src_rotate = (d_id % 2 == 0) - // d_id + 1 - ? static_cast(src[offset_src + stride_d]) - // d_id - 1 - : static_cast(src[offset_src - stride_d]); + if (d_id % 2 == 0) { + v_src_rotate = static_cast(src[offset_src + stride_d]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(src[offset_src - stride_d]); + v_sin = -shared_mem_sin[d_id - 1]; + } } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } - // handle the tail + // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel( offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } +template +__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out, + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + // Split the shared memory into cos and sin parts for q or k + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + 
shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); + +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + v_sin = shared_mem_sin[d_id]; + float v_src = src[offset_src]; + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? -static_cast(src[offset_src + (d2 / 2)]) + : static_cast(src[offset_src + (d2 / 2 - d2)]); + } else { + v_src_rotate = (d_id % 2 == 0) ? -static_cast(src[offset_src + 1]) + : static_cast(src[offset_src - 1]); + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = src[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { + int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = src[offset_src]; + } + } + } + } +} + +template +__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs, + scalar_t *out, const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + // Split the shared memory into cos and sin parts for q or k + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + + float v_src = grad_out[offset_src]; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + float v_src_rotate; + if (!interleaved) { + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2)]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2 - d2)]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } + } else { + if (d_id % 2 == 0) { + 
v_src_rotate = static_cast(grad_out[offset_src + 1]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(grad_out[offset_src - 1]); + v_sin = -shared_mem_sin[d_id - 1]; + } + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = grad_out[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = grad_out[offset_src]; + } + } + } + } +} + +template +__global__ void fused_qkv_rope_forward_kernel( + const scalar_t *qkv_input, const float *q_freqs, const float *k_freqs, + const int *start_positions, scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int total_d = q_split_arg + k_split_arg + v_split_arg; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + int begin_offset = (start_positions == nullptr) ? 
0 : start_positions[b_id]; + s_id_for_freqs = s_id + begin_offset; + } + fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg); + fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg); + fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg); +} + +template +__global__ void fused_qkv_rope_backward_kernel( + const scalar_t *grad_out_q, const scalar_t *grad_out_k, const scalar_t *grad_out_v, + const float *q_freqs, const float *k_freqs, scalar_t *qkv_grad, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + int total_d = q_split_arg + k_split_arg + v_split_arg; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_q, h, d, d2, 0, total_d, + q_split_arg); + fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d, + k_split_arg); + fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d, + v_split_arg); +} + template void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs, const int *start_positions, scalar_t *output, @@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_forward_kernel<<>>( + fused_rope_forward_kernel<<>>( input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); @@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_backward_kernel<<>>( + fused_rope_backward_kernel<<>>( output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs, + const float *k_freqs, const int *start_positions, + scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + int warps_per_block = (h <= 8) ? h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_forward_kernel<<>>( + qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out, + const scalar_t *v_grad_out, const float *q_freqs, + const float *k_freqs, scalar_t *qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, + const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + const int warps_per_block = (h <= 8) ? 
h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_backward_kernel<<>>( + q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, const Tensor &start_positions, Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, @@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c stride_b, stride_h, stride_d, stream);); } +void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs, + const Tensor &start_positions, Tensor *q_out, Tensor *k_out, + Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + qkv_input.data.dtype, scalar_t, + fused_qkv_rope_forward_launcher(reinterpret_cast(qkv_input.data.dptr), + reinterpret_cast(q_freqs.data.dptr), + reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), + reinterpret_cast(q_out->data.dptr), + reinterpret_cast(k_out->data.dptr), + reinterpret_cast(v_out->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} + +void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out, + const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs, + Tensor *qkv_grad_input, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + q_grad_out.data.dtype, scalar_t, + fused_qkv_rope_backward_launcher(reinterpret_cast(q_grad_out.data.dptr), + reinterpret_cast(k_grad_out.data.dptr), + reinterpret_cast(v_grad_out.data.dptr), + reinterpret_cast(q_freqs.data.dptr), + reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(qkv_grad_input->data.dptr), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} } // end namespace transformer_engine void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, @@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } + +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, 
+ const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_forward); + using namespace transformer_engine; + fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out), + convertNVTETensorCheck(v_out), qkv_format, interleaved, cp_size, cp_rank, + s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream); +} + +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_backward); + using namespace transformer_engine; + fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out), + *convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream); +} diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 3af7e42c2..51ffa7a19 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -92,7 +92,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Reduce to get the sum of aggregated_probs_per_expert */ CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id); __syncwarp(); if (lane_id == 0) { @@ -148,7 +148,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Reduce to get the sum of aggregated_probs_per_expert */ CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id); __syncwarp(); if (lane_id == 0) { @@ -181,9 +181,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, config.stream = stream; // Update the max cluster size based on the device - cudaOccupancyMaxPotentialClusterSize( + NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize( &cluster_size, - reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config); + reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config)); cudaLaunchAttribute attribute[1]; attribute[0].id = cudaLaunchAttributeClusterDimension; @@ -193,15 +193,16 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, config.numAttrs = 1; config.attrs = attribute; - cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel, probs, - tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, - coeff, aux_loss, Const_buf); + NVTE_CHECK_CUDA(cudaLaunchKernelEx( + &config, fused_moe_aux_loss_forward_kernel, probs, tokens_per_expert, + 
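A hedged usage sketch of the new C entry point declared above: the caller owns the NVTETensor handles and supplies the packed-QKV geometry, with qkv_split_arg_list_{0,1,2} carrying the per-tensor splits of the packed input. The wrapper name is hypothetical, the header path and the NVTE_SBHD layout choice are assumptions; only the parameter order declared in this diff is relied on.

#include <cuda_runtime.h>
#include <transformer_engine/fused_rope.h>  // assumed header for the nvte_fused_qkv_rope_* API

// Hypothetical caller: all NVTETensor handles are created elsewhere.
void apply_qkv_rope_forward(NVTETensor qkv, NVTETensor q_freqs, NVTETensor k_freqs,
                            NVTETensor start_positions, NVTETensor q_out, NVTETensor k_out,
                            NVTETensor v_out, int s, int b, int h, int d, int d2,
                            int q_split, int k_split, int v_split, cudaStream_t stream) {
  // NVTE_SBHD, interleaved=false and no context parallelism are example choices.
  nvte_fused_qkv_rope_forward(qkv, q_freqs, k_freqs, start_positions, q_out, k_out, v_out,
                              NVTE_SBHD, /*interleaved=*/false,
                              /*cp_size=*/1, /*cp_rank=*/0, s, b, h, d, d2,
                              q_split, k_split, v_split, stream);
}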
total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf)); } else { #endif size_t smem_size = sizeof(CompType) * num_cols; fused_moe_aux_loss_forward_kernel <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf); + NVTE_CHECK_CUDA(cudaGetLastError()); #ifndef __HIP_PLATFORM_AMD__ } #endif @@ -235,7 +236,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, // Loop: for all positions in each row for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { float C_coeff = Const_buf[0]; - IndexType tokens_per_expert_i = tokens_per_expert[i]; + double tokens_per_expert_i = static_cast(tokens_per_expert[i]); double grad_aux_loss_value = static_cast(grad_aux_loss[0]); // Loop: for all rows for (int j = global_warp_id; j < num_rows; j += global_warp_num) { @@ -254,6 +255,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf, int grid_size = (num_rows + block_size - 1) / block_size; fused_moe_aux_loss_backward_kernel<<>>( Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert, diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 91a4bbb53..03d22942b 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi if (score_function == 0) { if (topk > 1) { - auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id); + auto sum_logits = + warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_logits[i] = static_cast(static_cast(local_logits[i]) / (static_cast(sum_logits) + epsilon)); @@ -150,6 +151,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( <<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts, @@ -231,13 +233,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int */ // Sigmoid Post-processing bwd when topk > 1 if (topk > 1 && score_function == 0) { - auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id); + auto sum_fwd_input = + warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id); // Put the result of output * grad to the comp_buf for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i]; } __syncwarp(); - auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id); + auto sum_Output_x_Grad = + warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id); // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_grad[i] = @@ -283,6 +287,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( <<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, grad_logits); + 
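The hunks in these router files consistently add NVTE_CHECK_CUDA(cudaGetLastError()) immediately after raw <<<...>>> launches. A minimal sketch of the same pattern with plain CUDA error handling (NVTE_CHECK_CUDA itself is TE's internal macro; the kernel below is a stand-in):

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float *out) { out[threadIdx.x] = 1.0f; }

// Launch-side failures (bad grid/block, excessive shared memory, ...) are only
// surfaced through cudaGetLastError(); checking it right after the launch is
// exactly what the added NVTE_CHECK_CUDA calls do.
void checked_launch(float *out, cudaStream_t stream) {
  dummy_kernel<<<1, 32, 0, stream>>>(out);
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "kernel launch failed: %s\n", cudaGetErrorString(err));
    std::abort();
  }
}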
NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 443931aa3..32a78a078 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -227,7 +227,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( // score_function == 0 means sigmoid if (score_function == 0) { if (topk > 1) { - double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id); + double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { topk_scores[i] = static_cast(topk_scores[i]) / (sum_scores + epsilon); } @@ -264,6 +264,7 @@ void fused_topk_with_score_function_forward_kernel_launcher( <<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts, @@ -369,7 +370,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( /*data ptr = */ local_act_from_fwd, /*mask ptr = */ local_routing_map, /*data size = */ num_experts, - /*reduce func = */ sum, lane_id); + /*reduce func = */ ReduceFuncType::SUM, lane_id); // Put the result of output * grad to the comp_buf for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_comp_buf[i] = (local_routing_map[i] ? static_cast(local_grad[i]) * @@ -381,7 +382,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( /*data ptr = */ local_comp_buf, /*mask ptr = */ local_routing_map, /*data size = */ num_experts, - /*reduce func = */ sum, lane_id); + /*reduce func = */ ReduceFuncType::SUM, lane_id); // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { @@ -454,6 +455,7 @@ void fused_topk_with_score_function_backward_kernel_launcher( <<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, score_function, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_backward(const Tensor &routing_map, diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 488daf39b..3026d7e51 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -39,14 +39,28 @@ __device__ inline T sum(T a, T b) { return a + b; } +enum ReduceFuncType { + SUM, + MAX, +}; + template -__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T), +__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type, int lane_id) { + T (*reduce_func)(T, T); + double default_val = 0; + if (type == ReduceFuncType::SUM) { + reduce_func = sum; + default_val = 0; + } else if (type == ReduceFuncType::MAX) { + reduce_func = max; + default_val = -std::numeric_limits::infinity(); + } + // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = - lane_id < data_size ? 
static_cast(data_ptr[lane_id]) : static_cast(0); + volatile double val = lane_id < data_size ? static_cast(data_ptr[lane_id]) : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { //TODO: release after /opt/rocm/include/hip/amd_detail/amd_hip_bfloat16.h provide bf16 constructor from double #ifdef __HIP_PLATFORM_AMD__ @@ -83,13 +97,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i template __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, - T (*reduce_func)(T, T), int lane_id) { + ReduceFuncType type, int lane_id) { + T (*reduce_func)(T, T); + double default_val = 0; + if (type == ReduceFuncType::SUM) { + reduce_func = sum; + default_val = 0; + } else if (type == ReduceFuncType::MAX) { + reduce_func = max; + default_val = -std::numeric_limits::infinity(); + } + // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = lane_id < data_size && mask[lane_id] - ? static_cast(data_ptr[lane_id]) - : static_cast(0); + volatile double val = + lane_id < data_size && mask[lane_id] ? static_cast(data_ptr[lane_id]) : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { if (mask[i]) { val = reduce_func(static_cast(val), data_ptr[i]); @@ -142,7 +165,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ float sum_Output_x_Grad = warp_reduce_on_shmem( /*data ptr = */ comp_buf, /*data size = */ data_size, - /*reduce func = */ sum, lane_id); + /*reduce func = */ ReduceFuncType::SUM, lane_id); // In-place update for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { if (mask) { @@ -161,14 +184,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ template __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) { // 1. compute the max of value - float max_val = static_cast(warp_reduce_on_shmem(scores, data_size, max, lane_id)); + float max_val = + static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id)); // 2. value -> exp_value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { scores[i] = static_cast(exp(static_cast(scores[i]) - max_val)); } __syncwarp(); // 3. compute the sum of exp_value - float sum_val = static_cast(warp_reduce_on_shmem(scores, data_size, sum, lane_id)); + float sum_val = + static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id)); // 4. update the softmax value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { scores[i] = static_cast(scores[i]) / sum_val; @@ -179,19 +204,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i template __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, T *topk_scores, int lane_id) { + // Check if the index is masked by the later iteration + auto is_masked = [&topk_indices](int k, int index) { + if (k == 0) return false; + for (int i = 0; i < k; i++) { + if (topk_indices[i] == index) return true; + } + return false; + }; // Topk Times: Find the max value and its index // Then mask it, and record the index in the topk_indices // After looping topk times, the topk_indices will be the topk indices for (int k = 0; k < topk; k++) { // Find the max value and its index - volatile double val = - (lane_id < data_size) ? 
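The utils.h refactor above replaces a raw function-pointer argument with a ReduceFuncType enum so the reduction also knows its identity element: 0 for SUM, -infinity for MAX, which is what out-of-range or masked-out lanes must contribute. A host-side sketch of the same dispatch, with illustrative names:

#include <algorithm>
#include <limits>

enum class ReduceKind { kSum, kMax };  // mirrors ReduceFuncType { SUM, MAX }

// Choosing the identity together with the operator is the point of the change:
// a masked lane must contribute 0 to a sum but -inf to a max.
inline double reduce_init(ReduceKind kind) {
  return kind == ReduceKind::kSum ? 0.0 : -std::numeric_limits<double>::infinity();
}

inline double reduce_step(ReduceKind kind, double acc, double x) {
  return kind == ReduceKind::kSum ? acc + x : std::max(acc, x);
}

inline double reduce(const double *data, int n, ReduceKind kind) {
  double acc = reduce_init(kind);
  for (int i = 0; i < n; ++i) acc = reduce_step(kind, acc, data[i]);
  return acc;
}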
static_cast(scores[lane_id]) : static_cast(0); + volatile double val = (lane_id < data_size && !is_masked(k, lane_id)) + ? static_cast(scores[lane_id]) + : -std::numeric_limits::infinity(); volatile int index = (lane_id < data_size) ? lane_id : 0; // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - volatile double cur_val = scores[i]; + volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits::infinity() + : static_cast(scores[i]); if (cur_val > val) { val = cur_val; index = i; @@ -214,17 +249,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i if (lane_id == 0) { topk_indices[k] = index; topk_scores[k] = val; - scores[index] = - static_cast(-1.0) - val; // make the selected experts using val = - 1 - val } __syncwarp(); } - - // Reset the scores to the original value - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - scores[topk_indices[i]] = - static_cast(-1.0) - static_cast(scores[topk_indices[i]]); - } } // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future @@ -258,6 +285,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i using type = int64_t; \ { __VA_ARGS__ } \ } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } diff --git a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu index a06c8493a..db8749b16 100644 --- a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -359,6 +359,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_forward( scaled_aligned_causal_masked_softmax_warp_forward <<>>(dst, src, scale, microbatches, query_seq_len, key_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -369,6 +370,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_backward( scaled_aligned_causal_masked_softmax_warp_backward <<>>(gradInput, grad, output, scale, microbatches, query_seq_len, key_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } template diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 5eca4947f..f31dfa08d 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -519,6 +519,7 @@ void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const ac default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -631,6 +632,7 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, c default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -742,6 +744,7 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index af3027039..01ca632a9 100644 --- 
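naive_topk_and_mask previously marked a selected entry in place (scores[index] = -1 - val) and undid the marking after the loop; the new version instead skips already-selected indices through an is_masked check, so the scores buffer is never mutated. A small host-side reference of the same selection logic, for illustration only (not the warp-level code):

#include <limits>
#include <vector>

// Reference top-k by repeated argmax, skipping indices chosen in earlier
// rounds instead of overwriting the input scores.
inline void topk_by_repeated_argmax(const std::vector<double> &scores, int topk,
                                    std::vector<int> *indices, std::vector<double> *values) {
  indices->assign(topk, -1);
  values->assign(topk, 0.0);
  for (int k = 0; k < topk; ++k) {
    double best = -std::numeric_limits<double>::infinity();
    int best_idx = -1;
    for (int i = 0; i < static_cast<int>(scores.size()); ++i) {
      bool taken = false;
      for (int j = 0; j < k; ++j) taken = taken || (*indices)[j] == i;
      if (!taken && scores[i] > best) { best = scores[i]; best_idx = i; }
    }
    (*indices)[k] = best_idx;
    (*values)[k] = best;
  }
}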
a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -451,6 +451,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const in default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -567,6 +568,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input, default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index cca994299..9c2ca9b4c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -11,7 +11,7 @@ #include #include #endif // #ifndef __HIP_PLATFORM_AMD__ -#include + #include #include #include @@ -24,28 +24,13 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#ifndef __HIP_PLATFORM_AMD__ +#include "cutlass_grouped_gemm.cuh" +#endif #ifndef __HIP_PLATFORM_AMD__ namespace { -cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { - using namespace transformer_engine; - switch (t) { - case DType::kFloat16: - return CUDA_R_16F; - case DType::kFloat32: - return CUDA_R_32F; - case DType::kBFloat16: - return CUDA_R_16BF; - case DType::kFloat8E4M3: - return CUDA_R_8F_E4M3; - case DType::kFloat8E5M2: - return CUDA_R_8F_E5M2; - default: - NVTE_ERROR("Invalid type"); - } -} - uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; @@ -244,10 +229,11 @@ namespace transformer_engine { #ifdef __HIP_PLATFORM_AMD__ //Forward declaration. The implementation is in rocm_gemm.cu void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, bool transa, bool transb, bool grad, - void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset = -1); + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void* workspace, size_t workspaceSize, + float alpha, float beta, bool use_split_accumulator, int math_sm_count, + int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, + hipStream_t stream, int compute_stream_offset = -1); #else // Use cublasLt using cublasHandleManager = detail::HandleManager; @@ -255,8 +241,9 @@ using cublasHandleManager = detail::HandleManagerflat_first_dim(); const int A1 = inputA->flat_last_dim(); @@ -312,13 +299,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "fp8 Aux output for gemm + gelu fusion not supported!"); } if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); + NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!"); } - float one = 1.0; - float zero = 0.0; - float beta = (accumulate) ? 
one : zero; - cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); cublasLtMatmulDesc_t operationDesc = nullptr; @@ -537,22 +520,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &epilogue, sizeof(epilogue))); if (counter != nullptr) { -#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) - NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", +#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); #endif #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( - "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); #endif #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ CUBLAS_VERSION < 130000 NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, - "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ", + "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, - "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; @@ -603,7 +586,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, static_cast(&beta), /* beta */ @@ -632,38 +615,30 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // namespace transformer_engine -// compute_stream_offset = -1 means the stream from outer rather than compute_streams -static void cublas_gemm_ex(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, - NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, - NVTETensor workspace, bool accumulate, bool use_split_accumulator, - int math_sm_count, cudaStream_t stream, int compute_stream_offset = -1) { +void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, + NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, + NVTETensor workspace, bool accumulate, bool use_split_accumulator, + int math_sm_count, cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensorCheck(D); - const Tensor *biasTensor = convertNVTETensorCheck(bias); - Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out); - Tensor *wspace = convertNVTETensorCheck(workspace); + Tensor *outputD = convertNVTETensor(D); + const Tensor *biasTensor = convertNVTETensor(bias); + Tensor *outputGelu = convertNVTETensor(pre_gelu_out); + Tensor *wspace = convertNVTETensor(workspace); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, -#ifdef __HIP_PLATFORM_AMD__ - transa, transb, -#else - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, -#endif //__HIP_PLATFORM_AMD__ - grad, wspace->data.dptr, wspace->data.shape[0], - accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream -#ifdef __HIP_PLATFORM_AMD__ - , compute_stream_offset -#endif //__HIP_PLATFORM_AMD__ - ); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false, + nullptr, stream); } -void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, - NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, - NVTETensor workspace, bool accumulate, bool use_split_accumulator, - int math_sm_count, cudaStream_t stream) { - NVTE_API_CALL(nvte_cublas_gemm); +void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, + const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, + bool transb, bool grad, NVTETensor workspace, float alpha, float beta, + bool use_split_accumulator, int math_sm_count, cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm_scaled); using namespace transformer_engine; const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); @@ -672,15 +647,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, -#ifdef __HIP_PLATFORM_AMD__ - transa, transb, -#else - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, - (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, -#endif - grad, wspace->data.dptr, wspace->data.shape[0], - accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? 
CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -694,20 +663,23 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor #ifndef __HIP_PLATFORM_AMD__ // Check CUDA and cuBLAS versions -#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) - NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", +#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); #endif #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) - NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", - CUBLAS_VERSION); + NVTE_ERROR( + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", + CUBLAS_VERSION); #endif - NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, - "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ", - cuda::cudart_version()); + NVTE_CHECK( + transformer_engine::cuda::cudart_version() >= 12020 && + transformer_engine::cuda::cudart_version() < 13000, + "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", + transformer_engine::cuda::cudart_version()); NVTE_CHECK( cublas_version() >= 120205 && cublas_version() < 130000, - "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); #endif //__HIP_PLATFORM_AMD__ @@ -722,24 +694,17 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, -#ifdef __HIP_PLATFORM_AMD__ - transa, transb, -#else - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, -#endif //__HIP_PLATFORM_AMD__ - grad, wspace->data.dptr, wspace->data.shape[0], - accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, - inputCounter, stream); + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + 1.0f, (accumulate) ? 
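The GEMM plumbing now threads explicit alpha/beta through cublas_gemm: the boolean accumulate of the existing entry points maps onto D = alpha * op(A) * op(B) + beta * D with alpha = 1 and beta in {0, 1}, while the new nvte_cublas_gemm_scaled exposes arbitrary scalars. A hedged caller sketch; tensor handles are assumed to be pre-built and the header path is an assumption:

#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>  // assumed location of the nvte_cublas_gemm* declarations

// D = 0.5 * op(A) * op(B) + 2.0 * D, i.e. scaled accumulation into an existing D.
void scaled_gemm_example(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor bias,
                         NVTETensor pre_gelu_out, NVTETensor workspace, cudaStream_t stream) {
  nvte_cublas_gemm_scaled(A, B, D, bias, pre_gelu_out,
                          /*transa=*/true, /*transb=*/false, /*grad=*/false, workspace,
                          /*alpha=*/0.5f, /*beta=*/2.0f,
                          /*use_split_accumulator=*/false, /*math_sm_count=*/0, stream);
}

// Equivalent of the legacy accumulate=true call, i.e. alpha = 1 and beta = 1.
void accumulate_gemm_example(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor bias,
                             NVTETensor pre_gelu_out, NVTETensor workspace, cudaStream_t stream) {
  nvte_cublas_gemm(A, B, D, bias, pre_gelu_out, true, false, false, workspace,
                   /*accumulate=*/true, /*use_split_accumulator=*/false, 0, stream);
}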
1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, + n_split, gemm_producer, inputCounter, stream); } -void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, - const NVTETensor *bias, NVTETensor *pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor *workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_stream_cublas_gemm); +void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { using namespace transformer_engine; int num_streams = nvte_get_num_compute_streams(); @@ -753,9 +718,26 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT } for (int i = 0; i < num_gemms; i++) { - cublas_gemm_ex(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams), i % num_streams); +#ifdef __HIP_PLATFORM_AMD__ + { + const Tensor *inputA = convertNVTETensorCheck(A[i]); + const Tensor *inputB = convertNVTETensorCheck(B[i]); + Tensor *outputD = convertNVTETensorCheck(D[i]); + const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); + Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); + Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, + wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 1.0f : 0.0f, + use_split_accumulator, math_sm_count, 0, 0, false, nullptr, + detail::get_compute_stream(i % num_streams), i % num_streams); + } +#else + nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, + workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, + detail::get_compute_stream(i % num_streams)); +#endif } // record events on compute streams @@ -769,6 +751,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT } } +void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor *workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_stream_cublas_gemm); + using namespace transformer_engine; + + // Deprecation warning + NVTE_WARN( + "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. 
" + "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when " + "applicable)."); + + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, + accumulate, use_split_accumulator, math_sm_count, stream); +} + #ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { @@ -777,4 +778,91 @@ using cublasHandleManager = detail::HandleManager("NVTE_USE_CUTLASS_GROUPED_GEMM", false); + const bool warn_fallback = + transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); + + auto cublas_path = [&]() { + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + workspace, accumulate, use_split_accumulator, math_sm_count, stream); + }; + + // Currently only support cutlass group gemm on Hopper Arch + if (!(is_hopper && use_cutlass)) { + cublas_path(); + return; + } + + auto is_empty_arr = [&](const NVTETensor *p) -> bool { + if (p == nullptr) return true; + for (int i = 0; i < num_gemms; ++i) { + if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false; + } + return true; + }; + + auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool { + int64_t ref_k = -1; + for (size_t i = 0; i < num_gemms; i++) { + const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]); + const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1]; + + if ((k & 127) != 0) return false; + + if (ref_k < 0) + ref_k = k; + else if (k != ref_k) + return false; + } + + return true; + }; + + auto is_supported_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_type = get_cuda_dtype(inputA->data.dtype); + auto B_type = get_cuda_dtype(inputB->data.dtype); + auto D_type = get_cuda_dtype(OutputD->data.dtype); + + return (A_type == B_type) && (A_type == D_type) && + ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); + }; + + // CUTLASS Grouped GEMM fast path (SM90/TMA) + // Conditions: + // - No fused epilogue: both bias and pre_gelu_out are empty. + // - Supported dtypes only: FP16/BF16 (FP32 accumulate). + // - Uniform K across groups and K % 128 == 0. + // - use_split_accumulator is ignored for FP16/BF16. + // - grad is irrelevant when bias/pre_gelu_out are empty. + // + // Otherwise, fall back to cuBLAS. + if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && + all_groups_uniform_k128(B, transb)) { + cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, + current_device, math_sm_count, stream); + } else { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); + } +#endif // __HIP_PLATFORM_AMD__ +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cu b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu new file mode 100644 index 000000000..18736c4f5 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ **************************************************************************************************/ + +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass_grouped_gemm.cuh" + +namespace transformer_engine { +namespace grouped_gemm { + +// Explicit template instantiation to match the template declarations in the .cuh +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); + +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream) { + using namespace transformer_engine; + auto* inputA = convertNVTETensorCheck(A[0]); + auto* inputB = convertNVTETensorCheck(B[0]); + + float one = 1.0; + float zero = 0.0; + float alpha = one; + float beta = (accumulate) ? one : zero; + + auto dispatch = [&](auto tag) { + using T = decltype(tag); + if (!transa && !transb) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (!transb && transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (transb && !transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else { + NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm."); + } + }; + + if (inputA->data.dtype == DType::kBFloat16) { + dispatch(cutlass::bfloat16_t{}); + } else if (inputA->data.dtype == DType::kFloat16) { + dispatch(cutlass::half_t{}); + } else { + NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported."); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh new file mode 100644 index 000000000..1add57132 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +// +// Copyright (c) 2025 Shopee Inc. All Rights Reserved. +// + +/** + * @file: cutlass_grouped_gemm.cuh + * @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com + * @date: 2025-08-08 16:20:00 + * @brief: cutlass group gemm kernel. 
+ **/ + +#pragma once + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "common/util/system.h" +#include "cute/tensor.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +template +using GroupedGemmInputALayout = + std::conditional_t; + +template +using GroupedGemmInputBLayout = + std::conditional_t; + +using ProblemShapeType = cute::Shape; +using ProblemShape = cutlass::gemm::GroupProblemShape; // per group +template +struct GemmGivenSchedule { + using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand + using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand + using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands + + // A matrix configuration + using LayoutA = typename ScheduleConfig::LayoutA; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using LayoutB = typename ScheduleConfig::LayoutB; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using LayoutC = typename ScheduleConfig::LayoutC; // Layout type for C and D matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = + typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, 
LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +struct ScheduleConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + // TODO(Alan): Add tuning for different scenarios to select the optimal configuration, + // as the current configuration may not be the best. + + // using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + // using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + // using TileShape = Shape; + // using ClusterShape = Shape; + + using LayoutA = GroupedGemmInputALayout; + using LayoutB = GroupedGemmInputBLayout; + using LayoutC = cutlass::layout::RowMajor; + using DataType = DataType_; +}; + +template +using GemmGrouped = typename GemmGivenSchedule>::Gemm; + +template +typename GemmT::Arguments MakeArguments(int num_experts, void* problem_sizes_host, + void* problem_sizes, const ElementA** ptr_A, + StrideA* stride_A, const ElementB** ptr_B, + StrideB* stride_B, ElementC** ptr_C, StrideC* stride_C, + float alpha, float beta, int device, int math_sm_count) { + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + + cutlass::KernelHardwareInfo kernel_hw_info = + cutlass::KernelHardwareInfo::make_kernel_hardware_info( + device, math_sm_count); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + arguments = + typename GemmT::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, reinterpret_cast(problem_sizes), + reinterpret_cast(problem_sizes_host)}, + {ptr_A, stride_A, ptr_B, stride_B}, + { + fusion_args, + (beta > 0.0) ? 
(const ElementC**)ptr_C : nullptr, // NOLINT(*) + stride_C, + ptr_C, + stride_C, + }, + kernel_hw_info}; + + return arguments; +} + +template +inline __device__ __host__ T ROUND_UP(T m, T n) { + return (m + n - 1) / n * n; +} + +template +void debug_type() { + std::cout << typeid(T).name() << std::endl; +} + +int64_t inline getGemmCoordSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL)); +} + +int64_t inline getPtrSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(half*), 128UL)); +} + +int64_t inline getLddSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(int64_t), 128UL)); +} + +// cpu workspace size is 4MB +static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024; + +static char* getHostWorkspace() { + static std::once_flag flag; + static std::shared_ptr workspace; + + std::call_once(flag, [&]() { + workspace = + std::shared_ptr(reinterpret_cast(std::malloc(kCPUWorkSpaceSize)), [](char* p) { + if (p) std::free(p); + }); + + if (!workspace) { + throw std::bad_alloc(); + } + }); + + return workspace.get(); +} + +template +void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + NVTETensor* workspace, float alpha, float beta, int num_gemms, + cudaStream_t stream, int device, int math_sm_count) { + using Gemm = GemmGrouped; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + + typename Gemm::Arguments arguments; + size_t kernel_workspace_size = Gemm::get_workspace_size(arguments); + auto gemm_coord_size = getGemmCoordSize(num_gemms); + auto ptr_size = getPtrSize(num_gemms); + auto ldd_size = getLddSize(num_gemms); + auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size; + + NVTE_CHECK( + param_workspace_size < kCPUWorkSpaceSize, + "Insufficient kCPUWorkSpaceSize size: required=", static_cast(param_workspace_size), + ", available=", static_cast(kCPUWorkSpaceSize), " for CUTLASS grouped GEMM."); + + auto total_workspace_size = param_workspace_size + kernel_workspace_size; + transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]); + + NVTE_CHECK(total_workspace_size < wspace->numel(), "Insufficient workspace[0] size: required=", + static_cast(total_workspace_size), + ", available=", static_cast(wspace->numel()), " for CUTLASS grouped GEMM."); + + char* workspace_ptr = reinterpret_cast(wspace->data.dptr); + + char* kernel_workspace_ptr = nullptr; + + char* host_workspace = getHostWorkspace(); + + ProblemShapeType* problem_sizes_host = reinterpret_cast(host_workspace); + + ElementA** ptr_A_host = reinterpret_cast(host_workspace + gemm_coord_size); + ElementB** ptr_B_host = reinterpret_cast(host_workspace + gemm_coord_size + ptr_size); + ElementC** ptr_C_host = + reinterpret_cast(host_workspace + gemm_coord_size + 2 * ptr_size); + int64_t* lda_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + int64_t* ldb_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + int64_t* ldc_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * 
ldd_size); + + for (size_t i = 0; i < num_gemms; i++) { + const transformer_engine::Tensor* inputA = transformer_engine::convertNVTETensorCheck(A[i]); + const transformer_engine::Tensor* inputB = transformer_engine::convertNVTETensorCheck(B[i]); + transformer_engine::Tensor* outputD = transformer_engine::convertNVTETensor(D[i]); + + const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0]; + const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1]; + const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1]; + + auto problem = ProblemShapeType(m, n, k); + problem_sizes_host[i] = problem; + + ptr_A_host[i] = reinterpret_cast(inputA->data.dptr); + ptr_B_host[i] = reinterpret_cast(inputB->data.dptr); + ptr_C_host[i] = reinterpret_cast(outputD->data.dptr); + + lda_host[i] = LayoutA::packed({m, k}).stride(0); + ldb_host[i] = LayoutB::packed({k, n}).stride(0); + ldc_host[i] = LayoutC::packed({m, n}).stride(0); + } + + cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice, + stream); + + char* param_workspace_ptr = workspace_ptr; + ProblemShapeType* problem_sizes_device = reinterpret_cast(param_workspace_ptr); + const ElementA** ptr_A = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size); + const ElementB** ptr_B = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size); + ElementC** ptr_C = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 2 * ptr_size); + + StrideA* lda = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + StrideB* ldb = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + StrideC* ldc = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + kernel_workspace_ptr = workspace_ptr + param_workspace_size; + + arguments = MakeArguments( + num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + alpha, beta, device, math_sm_count); + + Gemm gemm; + + // Check can implement the kernel. + if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + } + + // Initialize the kernel. + if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + } + + // Execute the kernel in the current stream. 
+ if (gemm.run(stream) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + } +} + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream); diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 94f1bbfbd..fef3966a5 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -939,13 +939,9 @@ void hipblaslt_gemm(const Tensor *inputA, bool grad, void* workspace, size_t workspaceSize, - bool accumulate, + float alpha, float beta, bool use_split_accumulator, int math_sm_count, - int m_split, - int n_split, - bool gemm_producer, - const Tensor *inputCounter, hipStream_t stream, hipblasLtHandle_t handle ) { @@ -979,7 +975,7 @@ void hipblaslt_gemm(const Tensor *inputA, << " gelu=" << (outputPreGelu->data.dptr != nullptr) << " use_fp8=" << use_fp8 << " scale_mode=" << (a_tensor ? "tensor" : a_block ? "mxfp8" : "unsupported") - << " accumulate=" << accumulate + << " alpha=" << alpha << " beta=" << beta << std::endl; } @@ -1026,13 +1022,6 @@ void hipblaslt_gemm(const Tensor *inputA, NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!"); } #endif - if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); - } - - float one = 1.0; - float zero = 0.0; - float beta = (accumulate) ? one : zero; int device_id; NVTE_CHECK_CUDA(hipGetDevice(&device_id)); @@ -1204,7 +1193,7 @@ void hipblaslt_gemm(const Tensor *inputA, if (HIPBLAS_STATUS_SUCCESS == hipblaslt_ext::matmulIsAlgoSupported( handle, operationDesc, - static_cast(&one), + static_cast(&alpha), Adesc, Bdesc, static_cast(&beta), @@ -1284,7 +1273,7 @@ void hipblaslt_gemm(const Tensor *inputA, // Warm-up call NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ @@ -1306,7 +1295,7 @@ void hipblaslt_gemm(const Tensor *inputA, { NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ @@ -1367,7 +1356,7 @@ void hipblaslt_gemm(const Tensor *inputA, // D = alpha * (A * B) + beta * C NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ @@ -1510,10 +1499,12 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, bool transa, bool transb, bool grad, - void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset) + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, + float alpha, float beta, bool use_split_accumulator, int math_sm_count, + [[maybe_unused]] int m_split, [[maybe_unused]] int n_split, + [[maybe_unused]] bool gemm_producer, [[maybe_unused]] 
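CutlassGroupedGemm above stages the per-group metadata (problem shapes, A/B/C pointer arrays, leading dimensions) in a host scratch buffer, copies it to the front of workspace[0] with one cudaMemcpyAsync, and places the CUTLASS kernel workspace right after it; every sub-region is rounded up to 128 bytes. A sketch of the same offset arithmetic, with illustrative names:

#include <cstddef>
#include <cstdint>

inline size_t round_up_128(size_t n) { return (n + 127) / 128 * 128; }

struct GroupedGemmWorkspaceLayout {
  size_t problem_sizes_offset, ptr_a_offset, ptr_b_offset, ptr_c_offset;
  size_t lda_offset, ldb_offset, ldc_offset;
  size_t param_bytes;   // bytes copied host->device; kernel workspace starts here
};

// Same ordering as the .cuh above: shapes, then three pointer arrays, then
// three leading-dimension arrays, each 128-byte aligned.
inline GroupedGemmWorkspaceLayout make_layout(int num_gemms, size_t problem_shape_bytes) {
  const size_t coord = round_up_128(num_gemms * problem_shape_bytes);
  const size_t ptrs  = round_up_128(num_gemms * sizeof(void *));
  const size_t ldds  = round_up_128(num_gemms * sizeof(int64_t));
  GroupedGemmWorkspaceLayout l;
  l.problem_sizes_offset = 0;
  l.ptr_a_offset = coord;
  l.ptr_b_offset = coord + ptrs;
  l.ptr_c_offset = coord + 2 * ptrs;
  l.lda_offset = coord + 3 * ptrs;
  l.ldb_offset = coord + 3 * ptrs + ldds;
  l.ldc_offset = coord + 3 * ptrs + 2 * ldds;
  l.param_bytes = coord + 3 * ptrs + 3 * ldds;
  return l;
}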
const Tensor *inputCounter, + hipStream_t stream, int compute_stream_offset) { // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); @@ -1521,19 +1512,21 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int B0 = inputB->flat_first_dim(); const int B1 = inputB->flat_last_dim(); + const bool is_transa = transa == CUBLAS_OP_T; + const bool is_transb = transb == CUBLAS_OP_T; + // GEMM dims in column-major order - const int m = transa ? A0 : A1; - const int n = transb ? B1 : B0; - const int k = transa ? A1 : A0; - NVTE_CHECK((transb ? B0 : B1) == k, + const int m = is_transa ? A0 : A1; + const int n = is_transb ? B1 : B0; + const int k = is_transa ? A1 : A0; + NVTE_CHECK((is_transb ? B0 : B1) == k, "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, ")"); - const int lda = transa ? k : m; - const int ldb = transb ? n : k; + const int lda = is_transa ? k : m; + const int ldb = is_transb ? n : k; const int ldd = m; - ServiceStreamCtl ss_ctl; bool use_service_stream = (math_sm_count != 0) ? get_service_stream(math_sm_count, stream, ss_ctl) : false; @@ -1551,14 +1544,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, handle = hipblaslt_handles[compute_stream_offset]; } - hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, - m, n, k, lda, ldb, ldd, - (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, - (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N, - grad, - workspace, workspaceSize, accumulate, use_split_accumulator, - math_sm_count, m_split, n_split, gemm_producer, - inputCounter, use_service_stream ? ss_ctl.stream : stream, handle); + hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd, transa, + transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator, + math_sm_count, use_service_stream ? ss_ctl.stream : stream, handle); if (use_service_stream) { diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h new file mode 100644 index 000000000..14cf56a00 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -0,0 +1,156 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_gemm.h + * \brief Functions for distributed (multi-GPU) matrix multiplication. + * + * This API is a TE-native binding to cuBLASMp library. + * Refer here: https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for specific + * patterns, which allow communication-computation overlap. 
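The ROCm path keeps the same convention as the cuBLAS path: TE tensors are row-major (A is A0 x A1), while hipBLASLt/cuBLASLt take column-major GEMM dimensions, so m/n/k and the leading dimensions are derived from the transpose flags exactly as in the cublas_gemm hunk above. A small self-contained sketch of that mapping:

struct GemmDims { int m, n, k, lda, ldb, ldd; };

// Row-major inputs A (A0 x A1) and B (B0 x B1) mapped to column-major GEMM
// dimensions, matching the is_transa/is_transb logic above.
inline GemmDims map_rowmajor_to_gemm(int A0, int A1, int B0, int B1,
                                     bool is_transa, bool is_transb) {
  GemmDims g;
  g.m = is_transa ? A0 : A1;
  g.k = is_transa ? A1 : A0;
  g.n = is_transb ? B1 : B0;
  // (is_transb ? B0 : B1) must equal g.k; the real code enforces this with NVTE_CHECK.
  g.lda = is_transa ? g.k : g.m;
  g.ldb = is_transb ? g.n : g.k;
  g.ldd = g.m;
  return g;
}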
+ * + * All GEMM functions here have the same computation semantic, as expressed + * on global matrices, similar to nvte_cublas_gemm call: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * Functions differ in matrix distribution patterns + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ +#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +typedef struct NVTECommGemmCtx NVTECommGemmCtx; + +enum NVTECommGemmAlgoType { + kNVTECommGemmAlgoDefault = 0, + kNVTECommGemmAlgoSplitP2P = 1, + kNVTECommGemmAlgoSplitMulticast = 2, + kNVTECommGemmAlgoAtomicP2P = 3, + kNVTECommGemmAlgoAtomicMulticast = 4 +}; + +/*! \brief Create a comm-gemm context. + * + * \param[in] comm NCCL communicator. + * \param[in] nranks Number of ranks. + * \param[in] rank Local rank. + */ +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank); + +/*! \brief Destroy a comm-gemm context. + * + * \param[in] ctx Context to destroy. + */ +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); + +/*! \brief Perform AllGather communication followed by GEMM + * + * Gathers distributed data from all ranks, then computes matrix multiplication. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by ReduceScatter communication + * + * Computes matrix multiplication, then distributes results across ranks with reduction. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. 
+ * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by AllReduce communication + * + * Computes matrix multiplication, then reduces results across all ranks. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Get local number of rows or columns. + * + * Utility function to get local dimension. + * Block size, nranks and local rank is derived from the context ctx. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] global_size Global dimension. 
+ */ +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_COMM_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526..4d65e26ce 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -36,7 +36,8 @@ enum class CommOverlapAlgo { SPLIT_PIPELINED_RS_P2P = 4, ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 + ATOMIC_GEMM_RS_P2P = 7, + EXTERNAL_BULK_OVERLAP_AG = 8, }; class CommOverlapCore { @@ -133,6 +134,11 @@ class CommOverlapCore { cudaStream_t stream_main) { NVTE_ERROR("Operation is not implemented."); } + + virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { @@ -198,6 +204,9 @@ class CommOverlapBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + + void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) override; }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -277,6 +286,15 @@ class CommOverlapP2PBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + + /* + ** This function overlaps the AG for the current communicator object with the GEMM for the overlap_gemm object. + ** The gemm for overlap_gemm is assumed to have been previously started. + */ + void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/dropout.h b/transformer_engine/common/include/transformer_engine/dropout.h new file mode 100644 index 000000000..6ba1ab912 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/dropout.h @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dropout.h + * \brief Functions for dropout. + */ + +#ifndef TRANSFORMER_ENGINE_DROPOUT_FP8_H_ +#define TRANSFORMER_ENGINE_DROPOUT_FP8_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Dropout forward kernel. + * + * \param[in] input Input tensor. + * \param[out] output Output tensor. + * \param[out] mask Mask tensor. Each bit corresponds to an + * output tensor entry. Ones indicate kept + * entries and zeros indicate dropped entries. + * \param[in] rng_state RNG engine inputs. + * \param[in] dropout_probability Dropout probability. + * \param[in] stream CUDA stream used for this operation. 
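Illustrative usage sketch for the comm-GEMM API declared above (not part of the patch): a minimal AllGather+GEMM call, assuming the local NVTETensor shards and the NCCL communicator are created elsewhere and that the header installs as <transformer_engine/comm_gemm.h>.

#include <cuda_runtime.h>
#include <nccl.h>
#include <transformer_engine/comm_gemm.h>

// Sketch only: a, b, d, bias and pre_act_out are local shards owned by the caller.
void all_gather_gemm_example(ncclComm_t comm, int nranks, int rank,
                             int64_t m, int64_t n, int64_t k,
                             NVTETensor a, NVTETensor b, NVTETensor d,
                             NVTETensor bias, NVTETensor pre_act_out,
                             cudaStream_t stream) {
  NVTECommGemmCtx* ctx = nvte_comm_gemm_ctx_create(comm, nranks, rank);

  // Local extent of a global dimension on this rank (block size, nranks and
  // rank are taken from the context).
  int64_t n_local = nvte_comm_gemm_numroc(ctx, n);
  (void)n_local;  // would normally drive the local shard allocation

  // D = AB, with the distributed operand gathered from all ranks first.
  nvte_all_gather_gemm(ctx, m, n, k, a, b, d, bias, pre_act_out,
                       /*transa=*/false, /*transb=*/false, /*grad=*/false,
                       /*accumulate=*/false, /*comm_sm_count=*/0, stream,
                       kNVTECommGemmAlgoDefault);

  nvte_comm_gemm_ctx_destroy(ctx);
}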
+ */ +void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask, + NVTETensor rng_state, float dropout_probability, cudaStream_t stream); + +/*! \brief Dropout backward kernel. + * + * \param[in] grad_output Gradient of output tensor. + * \param[out] mask Mask tensor. Each bit corresponds to an + * output tensor entry. Ones indicate kept + * entries and zeros indicate dropped entries. + * \param[out] grad_input Gradient of input tensor. + * \param[in] dropout_probability Dropout probability. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input, + float dropout_probability, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index f0817a97f..610868f93 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream); +/*! \brief Apply rotary positional embedding to the combined QKV input tensor. + * + * \param[in] qkv_input Combined QKV input tensor for fused rope. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. + * \param[out] q_out Output tensor for Q. + * \param[out] k_out Output tensor for K. + * \param[out] v_out Output tensor for V. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + +/*! \brief Compute the backward of the fused qkv rope. + * + * \param[in] q_grad_out Incoming gradient tensor for Q. + * \param[in] k_grad_out Incoming gradient tensor for K. + * \param[in] v_grad_out Incoming gradient tensor for V. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[out] qkv_grad_input Input gradient tensor to calculate. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. 
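A minimal usage sketch for the dropout kernels added above (not part of the patch; all handles are assumed to be pre-allocated NVTETensors, with mask holding one bit per output element and rng_state holding the RNG seed/offset).

#include <cuda_runtime.h>
#include <transformer_engine/dropout.h>

void dropout_example(NVTETensor input, NVTETensor output, NVTETensor mask,
                     NVTETensor rng_state, NVTETensor grad_output,
                     NVTETensor grad_input, cudaStream_t stream) {
  const float dropout_probability = 0.1f;

  // Forward pass: writes the kept/dropped bit pattern into `mask`.
  nvte_dropout_fwd(input, output, mask, rng_state, dropout_probability, stream);

  // Backward pass: the same mask zeroes the gradients of dropped entries.
  nvte_dropout_bwd(grad_output, mask, grad_input, dropout_probability, stream);
}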
+ * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index a68070308..58c0a1f96 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -46,6 +46,36 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations, + * allowing for using a scaling factor for the GEMM result and the accumulation input + * + * Computes: + * - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = alpha*AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(alpha*AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * \param[in] A The A matrix. + * \param[in] B The B matrix. + * \param[in,out] D Output matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_gelu_out Output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of the + * gradient computation. + * \param[out] workspace Workspace tensor. + * \param[in] alpha Scaling factor applied to the result of the GEMM + * \param[in] beta Scaling factor applied to original value of D when + * accumulating into it. beta=0 means no accumulation. + * \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM. + * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, + const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, + bool transb, bool grad, NVTETensor workspace, float alpha, float beta, + bool use_split_accumulator, int math_sm_count, cudaStream_t stream); + /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters. * * \warning Cublas atomic gemm uses a beta API and is not tested for all use cases. 
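To make the alpha/beta semantics of the new nvte_cublas_gemm_scaled entry point concrete, here is a minimal sketch (not part of the patch; tensor handles and workspace come from the caller). beta = 0.0f reproduces the non-accumulating nvte_cublas_gemm behaviour, beta = 1.0f the old accumulate=true path.

#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

void scaled_gemm_example(NVTETensor A, NVTETensor B, NVTETensor D,
                         NVTETensor bias, NVTETensor pre_gelu_out,
                         NVTETensor workspace, cudaStream_t stream) {
  // D = 0.5 * op(A) op(B) + 1.0 * D  (scale the GEMM result, accumulate into D)
  nvte_cublas_gemm_scaled(A, B, D, bias, pre_gelu_out,
                          /*transa=*/true, /*transb=*/false, /*grad=*/false,
                          workspace, /*alpha=*/0.5f, /*beta=*/1.0f,
                          /*use_split_accumulator=*/false, /*math_sm_count=*/0,
                          stream);
}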
@@ -105,12 +135,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor* workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream); +void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index c21fd2627..a01b2e5da 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -20,7 +20,6 @@ extern "C" { /*! \brief Computes L2 norm for a list of tensors. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -33,22 +32,19 @@ extern "C" { * \param[out] ret_per_tensor L2 norm for each tensor. * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, int per_tensor, - int max_chunks_per_tensor, const int device_id, - cudaStream_t stream); + int max_chunks_per_tensor, cudaStream_t stream); /*! \brief Computes L2 norm for a list of tensors after unscaling. * * Unscaling is only done for computing the L2 norm. The tensors themselves are not updated. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen * \param[in] inv_scale Scalar for the unscaling operation. * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. 
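A minimal sketch of the renamed grouped entry point nvte_multi_tensor_gemm (not part of the patch; the per-GEMM handle arrays are assumed to be built by the caller, one entry per GEMM, exactly as for the old nvte_multi_stream_cublas_gemm).

#include <cuda_runtime.h>
#include <transformer_engine/gemm.h>

void grouped_gemm_example(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
                          const NVTETensor* bias, NVTETensor* pre_gelu_out,
                          NVTETensor* workspace, int num_gemms, cudaStream_t stream) {
  // num_gemms independent GEMMs issued through a single call.
  nvte_multi_tensor_gemm(A, B, D, bias, pre_gelu_out, num_gemms,
                         /*transa=*/false, /*transb=*/false, /*grad=*/false,
                         workspace, /*accumulate=*/false,
                         /*use_split_accumulator=*/false, /*math_sm_count=*/0,
                         stream);
}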
*/ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, @@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, NVTETensor inv_scale, int per_tensor, int max_chunks_per_tensor, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, * \param[in] mode Whether to use AdamW (L2 penalty applied to params). * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, @@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * where the master parameters only store the remainder bits. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso * \param[in] mode Whether to use AdamW (L2 penalty applied to params). * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_param_remainder_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, - const float weight_decay, const int device_id, cudaStream_t stream); + const float weight_decay, cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * when model parameters are in Float8 precision. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. 
@@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda( * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. * \param[in] fp8_dtype FP8 data type for model parameters. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, @@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const NVTEDType fp8_dtype, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * with CUDA graph support and LR scheduling. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. * \param[in] inv_scale Scalar for the unscaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * with CUDA graph support, LR scheduling, and FP32 master weights. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda( * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. * \param[in] inv_scale Scalar for the unscaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_capturable_master_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream); /*! 
\brief Compute and apply gradient update to parameters for SGD optimizer. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda( * \param[in] first_run Whether momentum buffers have been initialized. * \param[in] wd_after_momentum Whether to applied weight decay after momentum update. * \param[in] scale Scalar for the scaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float wd, float momentum, float dampening, float lr, int nesterov, int first_run, int wd_after_momentum, float scale, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Check overflow and scale a list of tensors. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor * \param[in] num_tensor_lists Size (dim0) of tensor_lists. * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. * \param[in] scale Scalar for the scaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, - float scale, const int device_id, cudaStream_t stream); + float scale, cudaStream_t stream); /*! \brief Check overflow and scale a list of tensors. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens * \param[in] max_fp8 Maximum representible value in underlying FP8 format. * \param[in] force_pow_2_scales Ensure scaling factors are a power of 2. * \param[in] epsilon Term added to the denominator for numerical stability. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. 
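With device_id removed from the multi-tensor entry points, a call now looks like the sketch below (not part of the patch; tensor_lists is the usual [num_tensor_lists][num_tensors_per_list] array of handles for g, p, m, v, and the hyper-parameter values are placeholders).

#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>

void adam_step_example(NVTETensor noop_flag, NVTETensor** tensor_lists,
                       size_t num_tensor_lists, size_t num_tensors_per_list,
                       int step, cudaStream_t stream) {
  const int chunk_size = 65536;
  nvte_multi_tensor_adam_cuda(chunk_size, noop_flag, tensor_lists,
                              num_tensor_lists, num_tensors_per_list,
                              /*lr=*/1e-3f, /*beta1=*/0.9f, /*beta2=*/0.999f,
                              /*epsilon=*/1e-8f, step, /*mode=*/0,  // adamMode_t selector
                              /*bias_correction=*/1, /*weight_decay=*/0.01f,
                              stream);
}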
*/ -void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( - int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, - const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, - const int device_id, cudaStream_t stream); +void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + float max_fp8, int force_pow_2_scales, + float epsilon, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/normalization.h b/transformer_engine/common/include/transformer_engine/normalization.h index cf6b91b96..61d09f0fb 100644 --- a/transformer_engine/common/include/transformer_engine/normalization.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -26,7 +26,7 @@ extern "C" { * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta * @f] * - * Calling this function with workspace set to empty tensor will not perform the operation, + * Calling this function with workspace set to an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] x Input tensor of shape [N, H]. @@ -57,8 +57,8 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe * else * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. * - * Calling this function with workspace set to empty tensor will not perform the operation, - * but instead set the shape and type of these tensors to the required values. + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] dz Incoming gradient tensor of shape [N, H]. * \param[in] x Forward input tensor of shape [N, H]. @@ -92,9 +92,8 @@ void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETenso * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} * @f] * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] x Input tensor of shape [N, H]. * \param[in] gamma Gamma tensor of shape [H]. @@ -123,9 +122,8 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep * @f] * with respect to \f$x\f$ and \f$gamma\f$. * - * Calling this function with workspace, barrier, dgamma_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] dz Incoming gradient tensor of shape [N, H]. * \param[in] x Forward input tensor of shape [N, H]. @@ -144,6 +142,29 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor NVTETensor workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream); +/*! 
\brief Compute backward of RMSNorm and add additional tensor to output gradient + * + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] dz Incoming gradient tensor of shape [N, H]. + * \param[in] x Forward input tensor of shape [N, H]. + * \param[in] add Additional tensor to add to output gradient [N, H]. + * \param[in] rsigma Reciprocal of the root mean square of the input + * calculated over the last dimension. Shape: [N]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[out] dx Output gradient of shape [N, H]. + * \param[out] dgamma Gradient for gamma tensor of shape [H]. + * \param[out] workspace Workspace tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETensor add, + const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, + NVTETensor dgamma, NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); + #ifndef __HIP_PLATFORM_AMD__ /*! \brief Helper to enable cuDNN backend for normalization * diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 222c94225..89515108a 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -92,6 +92,21 @@ constexpr int amax_kernel_threads = 512; */ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Compute an FP8 tensor's amax with quantization config. + * + * The amax (maximum absolute value) of the input tensor is computed + * and written to the amax buffer of the output tensor, using the provided + * quantization configuration. + * One useful config is the noop tensor, which is needed by cuda graph. + * + * \param[in] input Input tensor. Must be unquantized. + * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[in] config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig config, cudaStream_t stream); + #ifdef __HIP_PLATFORM_AMD__ size_t nvte_amax_workspace_num_blocks(size_t N); @@ -104,9 +119,12 @@ size_t nvte_amax_workspace_num_blocks(size_t N); * \param[in] input Input tensor. Must be unquantized. * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. * \param[out] workspace Output tensor. Must be FP32. + * \param[in] config Quantization configuration. * \param[in] stream CUDA stream used for the operation. 
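The workspace-query convention described for nvte_rmsnorm_bwd_add can be sketched as below (not part of the patch; make_empty_tensor and allocate_like are hypothetical helpers standing in for whatever tensor-management layer the caller already uses).

#include <cuda_runtime.h>
#include <transformer_engine/normalization.h>

// Hypothetical helpers provided by the caller's tensor-management layer.
NVTETensor make_empty_tensor();
NVTETensor allocate_like(NVTETensor shape_and_dtype_query);

void rmsnorm_bwd_add_example(NVTETensor dz, NVTETensor x, NVTETensor add,
                             NVTETensor rsigma, NVTETensor gamma,
                             NVTETensor dx, NVTETensor dgamma,
                             int multiprocessor_count, cudaStream_t stream) {
  // Pass 1: an empty workspace only queries the required shape and dtype.
  NVTETensor workspace = make_empty_tensor();
  nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, workspace,
                       multiprocessor_count, /*zero_centered_gamma=*/false, stream);

  // Pass 2: allocate the queried workspace, then run the fused backward + add.
  workspace = allocate_like(workspace);
  nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, workspace,
                       multiprocessor_count, /*zero_centered_gamma=*/false, stream);
}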
*/ -void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, NVTETensor workspace, cudaStream_t stream); +void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, + NVTETensor workspace, const NVTEQuantizationConfig config, + cudaStream_t stream); #endif diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index de5a11eb7..079feb4a7 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -30,6 +30,20 @@ extern "C" { */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM + * + * \param[in] inputs Input tensors with non-swizzled scale_inv. + * \param[in,out] outputs Output tensors which host swizzled scale_inv. + * \param[in] num_tensors Number of input/output tensors. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index a7db5cba4..cc069ee3e 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -318,6 +318,14 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, NVTETensor output, cudaStream_t stream); +/*! \brief Swap the first two tensor dimensions. + * + * \param[in] input Input tensor of shape [M, N, ...]. + * \param[out] output Output tensor of shape [N, M, ...]. + * \param[in] stream CUDA stream used for the operation.
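For illustration (not part of the patch), the batched swizzle above replaces a loop of per-tensor nvte_swizzle_scaling_factors calls with a single launch; the handle arrays are assumed to be owned by the caller.

#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>

void swizzle_all_example(const NVTETensor* inputs, NVTETensor* outputs,
                         size_t num_tensors, cudaStream_t stream) {
  // One launch swizzles the scale_inv buffers of all num_tensors tensors.
  nvte_multi_tensor_swizzle_scaling_factors(inputs, outputs, num_tensors, stream);
}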
+ */ +void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index d395e1f3a..dbbfd64b8 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -8,6 +8,7 @@ transformer_engine::cuda::stream_priority_range*; transformer_engine::cuda::current_device*; transformer_engine::cuda_driver::get_symbol*; + transformer_engine::cuda_driver::ensure_context_exists*; transformer_engine::ubuf_built_with_mpi*; *transformer_engine::rtc*; transformer_engine::nvte_cudnn_handle_init*; diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index 72f150512..07a66176b 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -583,7 +583,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, - const float weight_decay, const int device_id, cudaStream_t stream) { + const float weight_decay, cudaStream_t stream) { // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -650,20 +650,20 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctor(), device_id, - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>( - (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), device_id, stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctorMaster(), + stream, beta1, beta2, bias_correction1, bias_correction2, + epsilon, lr, (adamMode_t)mode, weight_decay);)); } } else { if (num_tensor_lists == 4) { @@ -673,9 +673,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), device_id, - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -684,9 +684,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctorMaster(), - device_id, stream, beta1, beta2, bias_correction1, - bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + stream, beta1, beta2, 
bias_correction1, bias_correction2, + epsilon, lr, (adamMode_t)mode, weight_decay);)); } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -697,7 +696,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -739,8 +738,8 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMasterParamRemainder(), device_id, - stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + AdamFunctorMasterParamRemainder(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -750,7 +749,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const DType fp8_dtype, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -820,9 +819,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5, true>( (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), device_id, stream, beta1, - beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + AdamFunctorMaster(), stream, beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( fp8_dtype, FP8_T, @@ -830,9 +828,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctorMaster(), - device_id, stream, beta1, beta2, bias_correction1, - bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + stream, beta1, beta2, bias_correction1, bias_correction2, + epsilon, lr, (adamMode_t)mode, weight_decay);)); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -842,7 +839,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, const float beta1, const float beta2, const float epsilon, Tensor step, const int mode, const int bias_correction, const float weight_decay, Tensor inv_scale, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { // Check tensor list sizes // 4 tensor lists: g, p, m, v const size_t num_tensor_lists = tensor_lists.size(); @@ -874,7 +871,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableFunctor(), device_id, stream, beta1, beta2, + AdamCapturableFunctor(), stream, beta1, beta2, reinterpret_cast(step.data.dptr), bias_correction, epsilon, reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, 
reinterpret_cast(inv_scale.data.dptr));) @@ -887,8 +884,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, Tensor lr, const float beta1, const float beta2, const float epsilon, Tensor step, const int mode, const int bias_correction, const float weight_decay, - Tensor inv_scale, const int device_id, - cudaStream_t stream) { + Tensor inv_scale, cudaStream_t stream) { // Check tensor list sizes // 4 tensor lists: g, p, m, v, p_master const size_t num_tensor_lists = tensor_lists.size(); @@ -923,10 +919,10 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableMasterFunctor(), device_id, stream, beta1, - beta2, reinterpret_cast(step.data.dptr), bias_correction, - epsilon, reinterpret_cast(lr.data.dptr), (adamMode_t)mode, - weight_decay, reinterpret_cast(inv_scale.data.dptr));) + AdamCapturableMasterFunctor(), stream, beta1, beta2, + reinterpret_cast(step.data.dptr), bias_correction, epsilon, + reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, + reinterpret_cast(inv_scale.data.dptr));) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -939,28 +935,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_cuda); using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, - epsilon, step, mode, bias_correction, weight_decay, device_id, stream); + epsilon, step, mode, bias_correction, weight_decay, stream); } void nvte_multi_tensor_adam_param_remainder_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, - const float weight_decay, const int device_id, cudaStream_t stream) { + const float weight_decay, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda); using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_param_remainder_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, - epsilon, step, mode, bias_correction, weight_decay, device_id, stream); + epsilon, step, mode, bias_correction, weight_decay, stream); } void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, @@ -969,22 +965,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const NVTEDType fp8_dtype, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda); using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_fp8_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, 
num_tensors_per_list), lr, beta1, beta2, - epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), device_id, - stream); + epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), stream); } void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) { + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda); using namespace transformer_engine; @@ -992,14 +987,14 @@ void nvte_multi_tensor_adam_capturable_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, - bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); + bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream); } void nvte_multi_tensor_adam_capturable_master_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) { + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda); using namespace transformer_engine; @@ -1007,5 +1002,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, - bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); + bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream); } diff --git a/transformer_engine/common/multi_tensor/compute_scale.cu b/transformer_engine/common/multi_tensor/compute_scale.cu index ebdcfbb56..dc4eb8714 100644 --- a/transformer_engine/common/multi_tensor/compute_scale.cu +++ b/transformer_engine/common/multi_tensor/compute_scale.cu @@ -58,26 +58,27 @@ struct ComputeScaleAndScaleInvFunctor { void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, float max_fp8, bool force_pow_2_scales, - float epsilon, const int device_id, - cudaStream_t stream) { + float epsilon, cudaStream_t stream) { multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8, - force_pow_2_scales, epsilon); + ComputeScaleAndScaleInvFunctor(), stream, max_fp8, force_pow_2_scales, + epsilon); NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace multi_tensor_compute_scale } // namespace transformer_engine -void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( - int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, - const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, - const int device_id, cudaStream_t stream) { +void 
nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + float max_fp8, int force_pow_2_scales, + float epsilon, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda); using namespace transformer_engine; multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8, - force_pow_2_scales, epsilon, device_id, stream); + force_pow_2_scales, epsilon, stream); } diff --git a/transformer_engine/common/multi_tensor/l2norm.cu b/transformer_engine/common/multi_tensor/l2norm.cu index e80b04f97..d15d01f88 100644 --- a/transformer_engine/common/multi_tensor/l2norm.cu +++ b/transformer_engine/common/multi_tensor/l2norm.cu @@ -404,13 +404,12 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, Tensor output, Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor, - bool per_tensor, int max_chunks_per_tensor, const int device_id, - cudaStream_t stream) { + bool per_tensor, int max_chunks_per_tensor, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<1>( - BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor(), device_id, - stream, reinterpret_cast(output.data.dptr), + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor(), stream, + reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor);) @@ -419,26 +418,25 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, // This involves one more small kernel launches, but will be negligible end to end. // I could get rid of these by hacking the functor + multi tensor harness with persistence // logic, but keeping it simple for now - const OptionalCUDAGuard device_guard(device_id); cleanup<<>>( reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, reinterpret_cast(ret.data.dptr), per_tensor ? reinterpret_cast(ret_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor); + NVTE_CHECK_CUDA(cudaGetLastError()); } void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, Tensor output, Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor, - int max_chunks_per_tensor, const int device_id, - cudaStream_t stream) { + int max_chunks_per_tensor, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<1>( - BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor(), device_id, - stream, reinterpret_cast(inv_scale.data.dptr), + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor(), stream, + reinterpret_cast(inv_scale.data.dptr), reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor);) @@ -448,13 +446,13 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, // This involves one more small kernel launches, but will be negligible end to end. 
// I could get rid of these by hacking the functor + multi tensor harness with persistence // logic, but keeping it simple for now - const OptionalCUDAGuard device_guard(device_id); cleanup<<>>( reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, reinterpret_cast(ret.data.dptr), per_tensor ? reinterpret_cast(ret_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace multi_tensor_l2norm @@ -464,8 +462,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, int per_tensor, - int max_chunks_per_tensor, const int device_id, - cudaStream_t stream) { + int max_chunks_per_tensor, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda); using namespace transformer_engine; @@ -474,7 +471,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor, - max_chunks_per_tensor, device_id, stream); + max_chunks_per_tensor, stream); } void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, @@ -483,7 +480,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, NVTETensor inv_scale, int per_tensor, int max_chunks_per_tensor, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda); using namespace transformer_engine; @@ -492,5 +489,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), - *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream); + *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, stream); } diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh index 4727f3964..b78612181 100644 --- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh +++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh @@ -14,53 +14,6 @@ // This header is the one-stop shop for all your multi-tensor apply needs. -// Change device if needed. 
-class OptionalCUDAGuard { - public: - explicit OptionalCUDAGuard(int new_device) { - if (new_device < 0) return; - - int current_device; - NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); - - if (new_device != current_device) { - NVTE_CHECK_CUDA(cudaSetDevice(new_device)); - device_changed_ = true; - prev_device_ = current_device; - } - } - - OptionalCUDAGuard(const OptionalCUDAGuard &) = delete; - OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete; - - OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept - : prev_device_(other.prev_device_), device_changed_(other.device_changed_) { - other.device_changed_ = false; - } - - OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept { - if (this != &other) { - if (device_changed_) { - cudaSetDevice(prev_device_); - } - prev_device_ = other.prev_device_; - device_changed_ = other.device_changed_; - other.device_changed_ = false; - } - return *this; - } - - ~OptionalCUDAGuard() { - if (device_changed_) { - cudaSetDevice(prev_device_); - } - } - - private: - int prev_device_; - bool device_changed_ = false; -}; - // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; @@ -94,7 +47,7 @@ template void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const transformer_engine::Tensor &noop_flag, std::vector> tensor_lists, - T callable, const int device_id, cudaStream_t stream, ArgTypes... args) { + T callable, cudaStream_t stream, ArgTypes... args) { const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensors_per_list = tensor_lists[0].size(); @@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, TensorListMetadata tl; - const OptionalCUDAGuard device_guard(device_id); - tl.start_tensor_this_launch = 0; int loc_block_info = 0; int loc_tensor_info = 0; diff --git a/transformer_engine/common/multi_tensor/scale.cu b/transformer_engine/common/multi_tensor/scale.cu index 66a173bdb..ac457adb0 100644 --- a/transformer_engine/common/multi_tensor/scale.cu +++ b/transformer_engine/common/multi_tensor/scale.cu @@ -104,13 +104,13 @@ struct ScaleFunctor { void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, float scale, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[1][0]->dtype(), g_in_type, multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - ScaleFunctor(), device_id, stream, scale);)) + ScaleFunctor(), stream, scale);)) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, - float scale, const int device_id, cudaStream_t stream) { + float scale, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_scale_cuda); using namespace transformer_engine; multi_tensor_scale::multi_tensor_scale_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), - convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id, - stream); + convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream); } diff --git 
a/transformer_engine/common/multi_tensor/sgd.cu b/transformer_engine/common/multi_tensor/sgd.cu index 05106e46d..9235de330 100644 --- a/transformer_engine/common/multi_tensor/sgd.cu +++ b/transformer_engine/common/multi_tensor/sgd.cu @@ -127,8 +127,7 @@ struct SGDFunctor { void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, float wd, float momentum, float dampening, float lr, bool nesterov, bool first_run, - bool wd_after_momentum, float scale, const int device_id, - cudaStream_t stream) { + bool wd_after_momentum, float scale, cudaStream_t stream) { const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensors_per_list = tensor_lists[0].size(); @@ -154,29 +153,29 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, // Case 1. fp16, fp16, fp16, No if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) { multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<3, fp16, fp16>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } // Case 2. fp32, fp32, fp32, No else if (grad_type == DType::kFloat32 && // NOLINT(*) weight_type == DType::kFloat32 && num_tensor_lists == 3) { multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<3, float, float>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } // Case 3. fp16, fp32, fp32, Yes else if (grad_type == DType::kFloat16 && // NOLINT(*) weight_type == DType::kFloat32 && num_tensor_lists == 4) { multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<4, fp16, float>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } // Case 4. 
fp32, fp32, fp32, Yes else if (grad_type == DType::kFloat32 && // NOLINT(*) weight_type == DType::kFloat32 && num_tensor_lists == 4) { multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<4, float, float>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } else { NVTE_ERROR("Unsupported combination of weight and gradient types."); } @@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor const size_t num_tensor_lists, const size_t num_tensors_per_list, float wd, float momentum, float dampening, float lr, int nesterov, int first_run, int wd_after_momentum, float scale, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_sgd_cuda); using namespace transformer_engine; multi_tensor_sgd::multi_tensor_sgd_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum, - dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream); + dampening, lr, nesterov, first_run, wd_after_momentum, scale, stream); } diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 6189be7b3..de1175a16 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -166,8 +166,8 @@ void TeNormalizationPlan::_set_workspace() { if (_launch_params.barrier_bytes > 0) { _launch_params.params.barrier = reinterpret_cast(workspace_dptr + _launch_params.workspace_bytes); - (void)cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes, - _launch_params.stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(_launch_params.params.barrier, 0, + _launch_params.barrier_bytes, _launch_params.stream)); } #ifdef __HIP_PLATFORM_AMD__ if constexpr (std::is_same_v) { @@ -192,7 +192,7 @@ void TeNormalizationPlan::_set_workspace() { template <> void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, - void* dx_dptr, void* dz_dptr, + void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { NVTE_ERROR("Forward normalization should not call the backward execute function!"); @@ -202,8 +202,9 @@ template <> void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, void* dz_dptr, - void* dbeta_dptr, void* dgamma_dptr, - void* workspace_dptr, cudaStream_t stream) { + void* add_dptr, void* dbeta_dptr, + void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) { _launch_params.stream = stream; auto& kernel_params = _launch_params.params; @@ -213,6 +214,7 @@ void TeNormalizationPlan::execute(void* x_dptr, void* gamm kernel_params.rs = rsigma_dptr; kernel_params.dx = dx_dptr; kernel_params.dz = dz_dptr; + kernel_params.add = add_dptr; kernel_params.dgamma = dgamma_dptr; if (_is_layernorm) { @@ -484,8 +486,11 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, void* dz_dptr, - void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, - cudaStream_t stream) { + void* add_dptr, void* dbeta_dptr, 
void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + // cuDNN does not currently support fused backward+add + NVTE_CHECK(add_dptr == nullptr); + // Binding data pointers to graph tensors _variant_pack = { {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}}; diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 31d2c0b74..33618f222 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -158,6 +158,9 @@ struct BackwardKernelParams : public KernelParamsBase { // Input: gradient wrt. LN FWD output. void* dz; + // Input: extra tensor to add for fused backward+add + void* add; + // Workspace for Wgrad pre-reduction. void* dbeta_part; void* dgamma_part; @@ -169,12 +172,14 @@ struct BackwardKernelParams : public KernelParamsBase { void* dgamma; }; +using BackwardAddKernelParams = BackwardKernelParams; + #ifdef __HIP_PLATFORM_AMD__ enum class NVTE_Norm_Backend { Te }; #else enum class NVTE_Norm_Backend { Te, Cudnn }; #endif -enum class NVTE_Norm_Stage { Forward, Backward }; +enum class NVTE_Norm_Stage { Forward, Backward, BackwardAdd }; using TupleKeyType = std::tuple; struct TupleHash { @@ -257,8 +262,8 @@ class NormalizationPlanBase { cudaStream_t stream) = 0; virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, - void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, - void* workspace_dptr, cudaStream_t stream) = 0; + void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, + void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) = 0; private: virtual void _build() = 0; @@ -281,8 +286,8 @@ class TeNormalizationPlan : public NormalizationPlanBase { cudaStream_t stream) override; void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, - void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, - cudaStream_t stream) override; + void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) override; private: void _set_workspace(); @@ -311,8 +316,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { cudaStream_t stream) override; void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, - void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, - cudaStream_t stream) override; + void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) override; private: void _build() override; diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 8c4cdbcc0..8c6fccfb5 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -67,6 +67,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + + if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); + } #endif //__HIP_PLATFORM_AMD__ bool gamma_in_weight_dtype = false; @@ -191,7 +196,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te } else { 
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream); + dz.data.dptr, nullptr /*add*/, dbeta->data.dptr, dgamma->data.dptr, + workspace->data.dptr, stream); } return; } diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 09618c58d..757e8f900 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -16,16 +16,16 @@ using namespace transformer_engine::normalization; template -static void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_bwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_bwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -53,13 +53,14 @@ static void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), - Kernel_traits::SMEM_BYTES, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES, stream)); } using Kernel_traits_f = @@ -70,13 +71,14 @@ static void launch_tuned_(LaunchParams &launch_params, auto kernel_f = &ln_bwd_finalize_tuned_kernel; kernel_f<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } template -static void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_bwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; // Instantiate kernel @@ -91,8 +93,8 @@ static void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - (void)cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, 
Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -113,10 +115,11 @@ static void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } // Launch finalization kernel @@ -130,6 +133,7 @@ static void launch_general_(LaunchParams &launch_params, dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); kernel_final<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -138,8 +142,8 @@ static void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_ln_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 222994018..c2c045a15 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -15,15 +15,15 @@ using namespace transformer_engine::normalization; template -static void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_fwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_fwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -54,12 +54,14 @@ static void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES_FWD, stream)); } #ifdef __HIP_PLATFORM_AMD__ if 
(launch_params.params.mxfp8_out) { @@ -70,8 +72,8 @@ static void launch_tuned_(LaunchParams &launch_params, template -static void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_fwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_fwd_general_kernel; @@ -84,8 +86,8 @@ static void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -110,10 +112,11 @@ static void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } #ifdef __HIP_PLATFORM_AMD__ if (launch_params.params.mxfp8_out) { @@ -128,8 +131,8 @@ static void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_ln_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index 679fda32c..f0cefaded 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -80,6 +80,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != nullptr; for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { Ivec x[LDGS]; @@ -148,9 +149,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( compute_t b_ij = beta[it].data.elt[jt]; compute_t temp_output = g_ij * y_ij + b_ij; - if (params.fp8_out) { + if (requires_amax) { __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(temp_output)); + } + if (params.fp8_out) { temp_output = temp_output * scale; } @@ -163,16 +166,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( } #endif } - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); 
+ if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); @@ -242,6 +246,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != nullptr; for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { const int row = cta_row + warp_m; @@ -310,14 +315,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } // Apply fp8 factors - if (params.fp8_out) { + if (params.fp8_out || requires_amax) { #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; + if (params.fp8_out) { + z.data.elt[jt] = z_ij * scale; + } } } } @@ -337,17 +344,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } } - // Finalize fp8 factors - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index d084e5c06..6c85cc432 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -53,6 +53,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + + if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); + } #endif bool training = @@ -170,7 +175,76 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const } else { NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream); + dz.data.dptr, nullptr /*add*/, nullptr /*dbeta*/, dgamma->data.dptr, + workspace->data.dptr, stream); + } + return; +} + +void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const Tensor &rsigma, + const Tensor &gamma, Tensor *dx, Tensor *dgamma, Tensor *workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + using namespace transformer_engine; + + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(add.data.dtype == gamma.data.dtype); + NVTE_CHECK(rsigma.data.dtype == DType::kFloat32); + + 
NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + NVTE_CHECK(add.data.shape == x.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(add, "add"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + } + + // cuDNN does not currently support fused backward+add + NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; + +#ifndef __HIP_PLATFORM_AMD__ + // TE backend does not currently support zero_centered_gamma_in_weight_dtype + NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(), + "zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add"); +#endif + + bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dgamma->data.dptr, add.data.dptr); + bool gamma_in_weight_dtype = false; + + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, add.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, + workspace->data.dptr, stream); } return; } @@ -203,3 +277,19 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace), multiprocessorCount, zero_centered_gamma, stream); } + +void nvte_rmsnorm_bwd_add(const NVTETensor dz, // Nxhidden_size + const NVTETensor x, // Nxhidden_size + const NVTETensor add, // Nxhidden_size + const NVTETensor rsigma, // N, FP32! 
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_bwd_add); + using namespace transformer_engine; + rmsnorm_bwd_add(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x), + *convertNVTETensorCheck(add), *convertNVTETensorCheck(rsigma), + *convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma), + convertNVTETensor(workspace), multiprocessorCount, zero_centered_gamma, stream); +} diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh index 5d8a5b765..3f3cdd065 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -7,13 +7,31 @@ #ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ +#include + #include "../../utils.cuh" #include "../common.h" namespace transformer_engine { namespace normalization { -template +struct maybe_not_t {}; + +template +using maybe_t = std::conditional_t; + +template +union dx_add_t { + using add_t = maybe_t; + using dx_t = Ivec; + struct { + char _padding[sizeof(dx_t) > sizeof(add_t) ? sizeof(dx_t) - sizeof(add_t) : 0]; + add_t add; + }; + dx_t dx; +}; + +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel( BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; @@ -111,10 +129,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke } } + dx_add_t temp[LDGS]; + if constexpr (FusedAdd) { + idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + temp[it].add.load_from(params.add, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } + reduce_t result = reducer.allreduce({0, mdyy_local}, sum); mdyy_local = Get<1>::of(result) * rn; - Ivec dx[LDGS]; idx = row * Ktraits::VEC_COLS + c; #pragma unroll for (int it = 0; it < LDGS; it++) { @@ -123,9 +150,13 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke compute_t dy_tmp = dy[it * NUM_ELTS + jt]; compute_t y_tmp = y[it * NUM_ELTS + jt]; compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp)); - dx[it].data.elt[jt] = dx_tmp; + if constexpr (FusedAdd) { + compute_t add_tmp = temp[it].add.data.elt[jt]; + dx_tmp += add_tmp; + } + temp[it].dx.data.elt[jt] = dx_tmp; } - dx[it].store_to(params.dx, idx); + temp[it].dx.store_to(params.dx, idx); idx += Ktraits::VEC_COLS_PER_LDG; } } // end: grid stride loop @@ -274,7 +305,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi } } -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel( BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; @@ -379,14 +410,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_ #pragma unroll for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; it++, col += gdimn * NUM_ELTS) { - Ivec dx; + dx_add_t temp; + if constexpr (FusedAdd) { + temp.add.load_from_elts(params.add, row * params.cols + col, params.cols - col); + } #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { compute_t dy_ij = dy[it].data.elt[jt]; compute_t y_ij = y[it].data.elt[jt]; - 
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij)); + compute_t dx_ij = rs * (dy_ij - (mdyy * y_ij)); + if constexpr (FusedAdd) { + compute_t add_ij = temp.add.data.elt[jt]; + dx_ij += add_ij; + } + temp.dx.data.elt[jt] = dx_ij; } - dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); + temp.dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); } } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 8cff9962c..8a5ca4e39 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -14,17 +14,17 @@ using namespace transformer_engine::normalization; template -static void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) + int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false> +void launch_rmsnorm_bwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_tuned_kernel; + auto kernel = &rmsnorm_bwd_tuned_kernel; if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + int ctas_per_sm = 0; + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -52,13 +52,14 @@ static void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), - Kernel_traits::SMEM_BYTES, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES, stream)); } using Kernel_traits_f = @@ -69,19 +70,20 @@ static void launch_tuned_(LaunchParams &launch_params, auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel; kernel_f<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } template -static void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) + int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false> +void launch_rmsnorm_bwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; // Instantiate kernel using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_general_kernel; + auto kernel = &rmsnorm_bwd_general_kernel; // Configure kernel params const int rows = launch_params.params.rows; @@ -89,9 +91,9 @@ static void launch_general_(LaunchParams &launch_params, int ctas_per_col = launch_params.params.ctas_per_col; int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { - int ctas_per_sm; - (void)cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); + int 
ctas_per_sm = 0; + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -114,10 +116,11 @@ static void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } // Launch finalization kernel @@ -131,6 +134,7 @@ static void launch_general_(LaunchParams &launch_params, dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); kernel_final<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -139,15 +143,15 @@ static void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_rmsnorm_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ } // namespace -// Create rmsnorm tuned launch function and register. Macro signature: +// Create rmsnorm bwd tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL @@ -175,7 +179,7 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -// Create rmsnorm general launch function and register. Macro signature: +// Create rmsnorm bwd general launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL @@ -208,3 +212,108 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + +// Create fused rmsnorm bwd + add tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, + true); + +// Create fused rmsnorm bwd + add general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4, + true); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 829f22dd4..e001c9c00 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -15,16 +15,16 @@ using namespace transformer_engine::normalization; template -static void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_fwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_fwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = 
cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -55,12 +55,14 @@ static void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES_FWD, stream)); } #ifdef __HIP_PLATFORM_AMD__ if (launch_params.params.mxfp8_out) { @@ -71,8 +73,8 @@ static void launch_tuned_(LaunchParams &launch_params, template -static void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_fwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_fwd_general_kernel; @@ -85,8 +87,8 @@ static void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -111,10 +113,11 @@ static void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - (void)cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } #ifdef __HIP_PLATFORM_AMD__ if (launch_params.params.mxfp8_out) { @@ -129,8 +132,8 @@ static void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_rmsnorm_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index 3d77a4710..d057f9beb 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh +++ 
b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -77,6 +77,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != nullptr; for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { Ivec x[LDGS]; @@ -141,9 +142,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke } compute_t temp_output = g_ij * y_ij; - if (params.fp8_out) { + if (requires_amax) { __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(temp_output)); + } + if (params.fp8_out) { temp_output = temp_output * scale; } @@ -156,16 +159,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke #ifdef __HIP_PLATFORM_AMD__ } #endif - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); @@ -233,6 +237,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != nullptr; for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { const int row = cta_row + warp_m; @@ -286,14 +291,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ } // Apply fp8 factors - if (params.fp8_out) { + if (params.fp8_out || requires_amax) { #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; + if (params.fp8_out) { + z.data.elt[jt] = z_ij * scale; + } } } } @@ -313,17 +320,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ #endif } - // Finalize fp8 factors - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu index a18ea6d4a..d5f6aeecc 100644 --- a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu +++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu @@ -35,17 +35,20 @@ void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t switch (wait_kind) { case WaitKind::KERNEL_WAIT: wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, 
wait_value, signal_reset); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case WaitKind::NVSHMEM_WAIT: nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream); - cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, - CU_STREAM_WRITE_VALUE_DEFAULT); + NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)signal_reset, + CU_STREAM_WRITE_VALUE_DEFAULT)); break; case WaitKind::STREAM_WAIT: - cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value, - CU_STREAM_WAIT_VALUE_GEQ); - cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, - CU_STREAM_WRITE_VALUE_DEFAULT); + NVTE_CHECK_CUDA_DRIVER(cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)wait_value, CU_STREAM_WAIT_VALUE_GEQ)); + NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)signal_reset, + CU_STREAM_WRITE_VALUE_DEFAULT)); break; } } diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 5e4f2d0f1..bcdc0d298 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -267,11 +267,13 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, num_out_tokens); + NVTE_CHECK_CUDA(cudaGetLastError()); blocks = num_rows; threads = std::min(num_cols / kElementsPerAccess, 1024); moe_permute_kernel<<>>( input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_bwd @@ -283,6 +285,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, moe_permute_kernel<<>>( input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_bwd with probs @@ -306,6 +309,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, } else { NVTE_ERROR("topK cannot exceed 128."); } + NVTE_CHECK_CUDA(cudaGetLastError()); } } } @@ -330,11 +334,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f moe_unpermute_kernel<<>>( input, output, row_id_map, nullptr, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_fwd with probs moe_unpermute_kernel<<>>( input, output, row_id_map, prob, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index 3ec6ee849..f3c6b7952 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -53,13 +53,17 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax, template __launch_bounds__(amax_kernel_threads) __global__ + void amax_kernel(const InputType *input, float *amax, #ifdef __HIP_PLATFORM_AMD__ - void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N, - const size_t num_aligned_elements) { + float* __restrict__ block_amax, #else - void amax_kernel(const InputType *input, float *amax, const size_t N, - const size_t num_aligned_elements) { + [[maybe_unused]] void* __restrict__ block_amax, #endif + 
const size_t N, const size_t num_aligned_elements, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + VectorizedLoader loader(input, N); InputType max{0.f}; const int warp_id = threadIdx.x / THREADS_PER_WARP; @@ -108,10 +112,13 @@ __launch_bounds__(amax_kernel_threads) __global__ } template -void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax, - size_t block_capacity, cudaStream_t stream) { +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, +#ifdef __HIP_PLATFORM_AMD__ + float *block_amax, size_t block_capacity, +#endif + const float *noop_ptr, cudaStream_t stream) { // Zero out amax so we can update with atomic max - (void)cudaMemsetAsync(amax, 0, sizeof(float), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); // Return immediately if tensor is empty if (N == 0) { @@ -128,7 +135,7 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; num_blocks = std::min(num_blocks, max_blocks); - + constexpr void* block_amax = nullptr; #else constexpr size_t threads = amax_kernel_threads; size_t num_blocks = nvte_amax_workspace_num_blocks(num_aligned_elements); @@ -139,31 +146,18 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo // Launch kernel switch (align) { case Alignment::SAME_ALIGNED: -#ifdef __HIP_PLATFORM_AMD__ - amax_kernel - <<>>(input, amax, block_amax, N, num_aligned_elements); -#else amax_kernel - <<>>(input, amax, N, num_aligned_elements); -#endif + <<>>(input, amax, block_amax, N, num_aligned_elements, noop_ptr); break; case Alignment::SAME_UNALIGNED: -#ifdef __HIP_PLATFORM_AMD__ amax_kernel - <<>>(input, amax, block_amax, N, num_aligned_elements); -#else - amax_kernel - <<>>(input, amax, N, num_aligned_elements); -#endif + <<>>(input, amax, block_amax, N, num_aligned_elements, noop_ptr); break; case Alignment::DIFFERENT: { // This case is a logic error, since there is only one pointer (input) // in the alignment check. Still safe to process without vectorization. 
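// Editor's sketch (assumed names, not TE code) of the noop-flag early return added to
// amax_kernel above: the kernel reads a device-side flag and exits immediately when it is set,
// so the launch can be recorded into a CUDA graph once and effectively disabled at replay by
// writing 1.0f to the flag, without rebuilding the graph.
#include <cstddef>

__global__ void noop_demo_kernel(const float *input, float *amax, size_t n,
                                 const float *noop_flag) {
  if (noop_flag != nullptr && noop_flag[0] == 1.0f) return;  // skip all work when flagged
  float local = 0.f;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    local = fmaxf(local, fabsf(input[i]));
  }
  // All values are non-negative, so comparing the raw bit patterns as ints preserves the float
  // ordering and atomicMax acts as a float max here (the real kernel uses a warp/block reduce
  // followed by atomicMaxFloat instead).
  atomicMax(reinterpret_cast<int *>(amax), __float_as_int(local));
}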
-#ifdef __HIP_PLATFORM_AMD__ - amax_kernel<1, true, InputType><<>>(input, amax, block_amax, N, N); -#else - amax_kernel<1, true, InputType><<>>(input, amax, N, N); -#endif + amax_kernel<1, true, InputType> + <<>>(input, amax, block_amax, N, N, noop_ptr); break; } } @@ -199,14 +193,15 @@ size_t nvte_amax_workspace_num_blocks(size_t N) { #endif -void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream); -} +namespace { -void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) { +void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream, +#ifdef __HIP_PLATFORM_AMD__ + const NVTETensor workspace_, +#else + [[maybe_unused]] const NVTETensor workspace_, #endif - NVTE_API_CALL(nvte_compute_amax); + const NVTEQuantizationConfig config_) { using namespace transformer_engine; // Check input tensor @@ -258,6 +253,16 @@ void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor } #endif + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Compute amax TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); @@ -266,15 +271,41 @@ void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor #ifdef __HIP_PLATFORM_AMD__ block_amax, block_capacity, #endif - stream);); // NOLINT(*) + noop_ptr, stream);); // NOLINT(*) +} + +} // anonymous namespace + +void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax); + compute_amax_impl(input_, output_, stream, nullptr, nullptr); +} + +void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_, + const NVTEQuantizationConfig config_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax_with_config); + compute_amax_impl(input_, output_, stream, nullptr, config_); +} + +#ifdef __HIP_PLATFORM_AMD__ +void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, + NVTETensor workspace_, const NVTEQuantizationConfig config_, + cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax_with_workspace); + compute_amax_impl(input_, output_, stream, workspace_, config_); } +#endif namespace transformer_engine { namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, - const float epsilon) { + const float epsilon, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon, std::numeric_limits::max()); } @@ -320,10 +351,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, max_fp8 = Quantized_Limits::max_norm;); + // noop tensor for cuda graph + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // 
extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Update scale compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>( reinterpret_cast(output.amax.dptr), reinterpret_cast(output.scale.dptr), max_fp8, config.force_pow_2_scales, - config.amax_epsilon); + config.amax_epsilon, noop_ptr); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/recipe/fp8_block_scaling.cu b/transformer_engine/common/recipe/fp8_block_scaling.cu index cdb307238..0ff8c4040 100644 --- a/transformer_engine/common/recipe/fp8_block_scaling.cu +++ b/transformer_engine/common/recipe/fp8_block_scaling.cu @@ -197,6 +197,7 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_ reinterpret_cast(amax.data.dptr), amax_stride_h, amax_stride_w, h, w, start_offset, len);) + NVTE_CHECK_CUDA(cudaGetLastError()); } void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h, @@ -229,6 +230,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s reinterpret_cast(out.data.dptr), reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, h, w, start_offset, len);))) + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace fp8_block_scaling_recipe diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 81a150283..499f7bcff 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
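Most of the mechanical churn in these files is the same fix applied everywhere: each kernel launch is now followed by NVTE_CHECK_CUDA(cudaGetLastError()), and runtime calls such as cudaMemsetAsync or cudaFuncSetAttribute no longer discard their return status. A minimal sketch of the pattern; the macro below is an illustrative stand-in, not TE's actual NVTE_CHECK_CUDA definition:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_CUDA(call)                                               \
  do {                                                                 \
    const cudaError_t err_ = (call);                                   \
    if (err_ != cudaSuccess) {                                         \
      std::fprintf(stderr, "CUDA error: %s (%s:%d)\n",                 \
                   cudaGetErrorString(err_), __FILE__, __LINE__);      \
      std::abort();                                                    \
    }                                                                  \
  } while (0)

__global__ void example_kernel(float* out) { out[threadIdx.x] = 0.f; }

void checked_launch(float* out, cudaStream_t stream) {
  CHECK_CUDA(cudaMemsetAsync(out, 0, 128 * sizeof(float), stream));
  example_kernel<<<1, 128, 0, stream>>>(out);
  // A <<<>>> launch returns nothing; the launch error is queried afterwards.
  CHECK_CUDA(cudaGetLastError());
}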
@@ -17,8 +17,26 @@ #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" +namespace transformer_engine { namespace { +#ifdef __HIP_PLATFORM_AMD__ +#define __ldg(x) (*(x)) +#endif + +#ifndef __HIP_PLATFORM_AMD__ +constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; +constexpr __device__ __host__ int TB_DIM = 32; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; +constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; + +// output is in ~K-major interleaved blocks +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; +#else +// HIPCC does not support __host__ qualifier for variables +// and constexpr values do not need __device__ qualifier because they are compile-time constants +constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int TB_DIM = 32; constexpr int NEW_SF_TILE_DIM_K = 16; constexpr int N_SF_PER_TD_PER_TILE = 4; @@ -26,6 +44,7 @@ constexpr int N_SF_PER_TD_PER_TILE = 4; // output is in ~K-major interleaved blocks constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; constexpr int NEW_SF_TILE_DIM_M_I32 = 32; +#endif template __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { @@ -53,8 +72,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { } template -__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; @@ -68,21 +90,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons int m_tiles_in_tb = N_TILE_PER_TD; int k_tiles_in_tb = TB_DIM; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; } - if (blockIdx.y == gridDim.y - 1) { + if (bid_y == grid_dim_y - 1) { m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; } - const int32_t* input_i32 = reinterpret_cast(input) + - blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + - blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + + const int input_offset = + bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + const int32_t* input_i32 = reinterpret_cast(input) + input_offset; int32_t* output_i32[N_TILE_PER_TD]; #pragma unroll for (int i = 0; i < m_tiles_in_tb; i++) { - output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + - (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + output_i32[i] = reinterpret_cast(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 + + (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; } extern __shared__ int slm[]; @@ -92,13 +117,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons threadIdx.y < k_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { -#ifndef __HIP_PLATFORM_AMD__ - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * 
M_i32 + threadIdx.x * N_TILE_PER_TD)); -#else - regs_vec[i] = *(reinterpret_cast( - input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); -#endif + const int thread_offset = + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + // Pad zeros + if (padding_m || padding_k) { + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (input_offset + thread_offset) * sizeof(int) + j; + if (index / M >= original_K || index % M >= original_M) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // local shuffle @@ -133,6 +163,14 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons } } +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} + template __device__ inline void regs_shuffle(LType* regs_vec) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); @@ -150,8 +188,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) { } template -__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; @@ -161,14 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons int n_tiles_in_tb = N_TILES_IN_TB; const int K_i32 = K / 4; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - const int* input_i32 = reinterpret_cast(input) + - blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; - int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + - blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + + const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; + const int* input_i32 = reinterpret_cast(input) + input_offset; + int* output_i32 = reinterpret_cast(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 + + bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; extern __shared__ int4 slm_v4i[]; @@ -177,13 +221,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { -#ifndef __HIP_PLATFORM_AMD__ - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); -#else - regs_vec[i] = *(reinterpret_cast( - input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); -#endif + const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + if (padding_m || padding_k) { + // Pad zeros + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = 
(input_offset + thread_offset) * sizeof(int) + j; + if (index / K >= original_M || index % K >= original_K) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // shuffle regs @@ -208,9 +256,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons } } -} // namespace +template +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} -namespace transformer_engine { +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB +struct MultiSwizzleArgs { + // (input) Data buffers for input scaling factors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for swizzled scaling factors + void* output_list[kMaxTensorsPerKernel]; + // Input scaling factor m + int m_list[kMaxTensorsPerKernel]; + // Input scaling factor k + int k_list[kMaxTensorsPerKernel]; + // Input scaling factor m before padding + int original_m_list[kMaxTensorsPerKernel]; + // Input scaling factor k before padding + int original_k_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // Get block index in grid. Emulate 2D grid. 
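MultiSwizzleArgs packs up to 64 tensors into a single launch; block_range holds a prefix sum of per-tensor block counts, so each CUDA block first walks that array to find its tensor and then rebuilds 2D block coordinates from its offset inside the tensor's range. A hedged sketch of both halves, with illustrative names:

// Host side: block_range[0] = 0, block_range[i + 1] = block_range[i] + blocks(i).
void build_block_range(const int* blocks_per_tensor, int num_tensors, int* block_range) {
  block_range[0] = 0;
  for (int i = 0; i < num_tensors; ++i) {
    block_range[i + 1] = block_range[i] + blocks_per_tensor[i];
  }
}

// Device side: a linear scan is cheap because at most 64 tensors are packed.
__device__ int find_tensor_id(const int* block_range, int bid) {
  int tensor_id = 0;
  while (block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  return tensor_id;
}

// Rebuild 2D block coordinates from the flat index within the tensor's range.
__device__ void local_block_coords(int bid, int range_begin, int grid_dim_y,
                                   int* bid_x, int* bid_y) {
  const int local = bid - range_begin;
  *bid_x = local / grid_dim_y;
  *bid_y = local % grid_dim_y;
}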
+ const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB); + const int grid_dim_y = num_tiles_m; + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +template +__global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // Get block index in grid. Emulate 2D grid. + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM); + const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD); + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +} // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { @@ -264,38 +402,44 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_first_dim(); + const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 2: #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 1: #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + 
NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; default: NVTE_ERROR("Not valid vec_load_size."); break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } if (input->has_columnwise_data()) { int vec_load_size = (num_tiles_m - 1) % 4 + 1; @@ -303,48 +447,259 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 2: #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 1: #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; default: NVTE_ERROR("Not valid vec_load_size."); break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } // 2D block scaling } else { NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); } - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - printf("CUDA Error: %s\n", cudaGetErrorString(err)); - exit(-1); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, + const int vec_load_size, const bool is_rowwise, + cudaStream_t stream) { + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + /* Calculate number of CUDA blocks needed for each tensor. + * We have to do it here because we have to iterate over all tensors in this batch to + * get the minimum vec_load_size. 
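The vec_load_size heuristic picks the widest vector load, in 32-bit units, that still divides the tile count: (n - 1) % 4 + 1 evaluates to 4 when n is a multiple of 4 and to n % 4 otherwise, and a result of 3 is demoted to 1 because there is no int3 load. In the multi-tensor path the minimum across the whole batch is used so that one kernel instantiation fits every tensor. A small sketch of the same rule:

#include <algorithm>

// Sketch: vector width (in ints) for loading packed scale factors of one tensor.
int choose_vec_load_size(int num_tiles) {
  int vec = (num_tiles - 1) % 4 + 1;  // 4 if divisible by 4, otherwise num_tiles % 4
  if (vec == 3) {
    vec = 1;  // no int3 type, and int4/int2 loads would be misaligned
  }
  return vec;
}

// Batched variant: take the minimum raw width first, then demote 3 to 1,
// mirroring the order of operations in this patch.
int batch_vec_load_size(const int* tile_counts, int num_tensors) {
  int vec = 4;
  for (int i = 0; i < num_tensors; ++i) {
    vec = std::min(vec, (tile_counts[i] - 1) % 4 + 1);
  }
  return vec == 3 ? 1 : vec;
}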
+ */ + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + // Launch kernel + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + dim3 block_size(TB_DIM, TB_DIM); + if (is_rowwise) { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} +void multi_tensor_swizzle_scaling_factors(const std::vector& input, + std::vector& output, cudaStream_t stream) { + auto num_tensors = input.size(); + bool all_has_data = true; + bool all_has_columnwise_data = true; + for (size_t i = 0; i < num_tensors; i++) { + if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + "."); + } + // We don't allow empty tensors. They should be filtered out before calling this function. 
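Because CUDA kernel parameters are passed by value and the comment above notes the struct must stay under 4 KB, the multi-tensor entry point fills a fixed-capacity argument struct, flushes it with a launch whenever it becomes full, and finally launches whatever remains. A hedged sketch of that host-side loop; LaunchFn stands in for launch_multi_tensor_swizzle_scaling_factors and the struct is trimmed down:

#include <cstddef>
#include <vector>

constexpr int kMaxPerKernel = 64;  // keeps the by-value argument struct under 4 KB

struct BatchArgs {
  void* input_list[kMaxPerKernel];
  void* output_list[kMaxPerKernel];
  int num_tensors = 0;
};

template <typename LaunchFn>
void batch_and_launch(const std::vector<void*>& ins, const std::vector<void*>& outs,
                      LaunchFn launch) {
  BatchArgs args;
  for (size_t i = 0; i < ins.size(); ++i) {
    if (args.num_tensors == kMaxPerKernel) {
      launch(args);          // flush a full batch
      args.num_tensors = 0;  // reuse the struct for the next batch
    }
    const int pos = args.num_tensors++;
    args.input_list[pos] = ins[i];
    args.output_list[pos] = outs[i];
  }
  if (args.num_tensors > 0) {
    launch(args);            // flush the remainder
  }
}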
+ if (input[i]->data.numel() == 0) { + NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); + } + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + all_has_data &= input[i]->has_data(); + all_has_columnwise_data &= input[i]->has_columnwise_data(); + } + NVTE_CHECK(all_has_data || all_has_columnwise_data, + "All tensors should have data or columnwise data."); + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + if (all_has_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->scale_inv.shape[0]; + const int k = input[i]->scale_inv.shape[1]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK( + m * k == std::accumulate(output[i]->scale_inv.shape.begin(), + output[i]->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + } + + if (all_has_columnwise_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. 
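The original_m_list/original_k_list entries record the scale-factor extents before padding: with MXFP8, every 32 data elements along the scaled dimension share one scale factor, and the checks above require the stored scale matrix to be padded to multiples of the 128 x 4 swizzle tile. A hedged sketch of that bookkeeping; the round-up rule is an assumption about how the scale buffers were allocated upstream, while the 32-element block and tile extents follow the constants in this file:

#include <cstddef>

constexpr size_t kMxfp8BlockSize = 32;  // data elements covered by one scale factor
constexpr size_t kSfTileM = 128;        // swizzle tile extent in M
constexpr size_t kSfTileK = 4;          // swizzle tile extent in K

constexpr size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }

struct ScaleDims {
  size_t original_m, original_k;  // unpadded scale-factor extents
  size_t padded_m, padded_k;      // extents that are stored and swizzled
};

// Rowwise scaling of a [rows, cols] tensor: one scale per 32 columns.
ScaleDims rowwise_scale_dims(size_t rows, size_t cols) {
  ScaleDims d;
  d.original_m = rows;
  d.original_k = cols / kMxfp8BlockSize;
  d.padded_m = divup(d.original_m, kSfTileM) * kSfTileM;
  d.padded_k = divup(d.original_k, kSfTileK) * kSfTileK;
  return d;
}

The kernels then use the original extents only to decide which padded bytes must be zeroed, as in the padding branches earlier in this file.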
+ if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->columnwise_scale_inv.shape[1]; + const int k = input[i]->columnwise_scale_inv.shape[0]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); } } } // namespace transformer_engine @@ -359,3 +714,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud using namespace transformer_engine; swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); } + +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + std::vector input_list, output_list; + for (size_t i = 0; i < num_tensors; i++) { + input_list.push_back(convertNVTETensorCheck(inputs[i])); + output_list.push_back(convertNVTETensorCheck(outputs[i])); + } + multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index e24e4d33d..68d1f0ec5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -199,7 +199,8 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); + // Note: amax is supported for non-FP8 output as it can be fused into the computation + // and later used for quantization with no need to compute it separately NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); @@ -545,11 +546,11 @@ void nvte_zero_tensor(const 
NVTETensor tensor, cudaStream_t stream) { // Zero out tensor data if allocated if (t.data.dptr != nullptr) { const size_t size_in_bytes = nvte_tensor_size_bytes(tensor); - (void)cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream)); } // Set amax to 0 if allocated if (t.amax.dptr != nullptr) { - (void)cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream)); } } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index fdf92938c..1537b4181 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -429,6 +429,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu static_cast(output.scale.dptr), static_cast(output.amax.dptr), static_cast(output.scale_inv.dptr), row_length, num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } } else { NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index a73723926..abfa226e8 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + const SimpleTensor &noop_tensor, cudaStream_t stream); // enum class for rowwise usage enum class FP8BlockwiseRowwiseOption { @@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scale, cudaStream_t stream); + const bool pow_2_scale, const SimpleTensor &noop_tensor, + cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 8a2c39def..e8859fe66 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -336,6 +336,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reinterpret_cast(dbias->data.dptr), reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, reduce_dbias_num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); #ifdef __HIP_PLATFORM_AMD__ } #endif @@ -849,15 +850,16 @@ void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor * NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_transpose_fused_kernel_notaligned, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); #endif cast_transpose_fused_kernel_notaligned <<>>( param, row_length, num_rows, num_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } if constexpr (IS_DBIAS) { @@ -1311,10 +1313,10 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu (THREADS_PER_WARP + 1) * sizeof(Vec); if (full_tile) { 
#ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( dgated_act_cast_transpose_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); #endif dgated_act_cast_transpose_kernel(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( dgated_act_cast_transpose_kernel_notaligned, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); #endif dgated_act_cast_transpose_kernel_notaligned @@ -1346,6 +1349,7 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); }); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 2be365465..bf3856568 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -258,6 +258,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_aligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args_aligned.num_tensors = 0; } if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) { @@ -271,6 +272,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_unaligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args_unaligned.num_tensors = 0; } @@ -311,6 +313,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_aligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } if (kernel_args_unaligned.num_tensors > 0) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -323,6 +326,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_unaligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 79d8d215f..c3f085b87 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, const __grid_constant__ CUtensorMap tensor_map_output_t, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x, const size_t 
scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -473,8 +481,10 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream) { + const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_square_blockwise); + checkCuDriverContext(stream); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; @@ -492,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + if (return_transpose) { NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); @@ -539,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - tensor_map_output_trans, pow_2_scale); + tensor_map_output_trans, pow_2_scale, noop_ptr); } else { block_scaled_cast_transpose_kernel_notaligned @@ -550,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - pow_2_scale); + pow_2_scale, noop_ptr); } // full-tile ) // return_transpose ) // OutputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 6f5c0f3a6..d38bf7996 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scaling) { + const bool pow_2_scaling, const float* noop_ptr) { + // skip execution if noop + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE; bool return_columnwise_gemm_ready = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; @@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption 
columnwise_option, - const bool pow2_scale, cudaStream_t stream) { + const bool pow2_scale, const SimpleTensor& noop_tensor, + cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; @@ -573,23 +579,30 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor "Input and output_t must have the same shape for columnwise non-transpose case."); } } - - NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + // output may not be defined if rowwise quantization is not needed. + NVTE_CHECK(output.dtype == output_t.dtype, + "output and output_t need to have the same dtype."); + } NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2."); bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT; size_t scale_t_k = scale_inv_t.shape[1]; scale_t_stride_x = columnwise_compact ? 1 : scale_t_k; scale_t_stride_y = columnwise_compact ? scale_t_k : 1; } + auto output_dtype = + rowwise_option != FP8BlockwiseRowwiseOption::NONE ? output.dtype : output_t.dtype; const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output.dtype, OutputType, + output_dtype, OutputType, dim3 grid(num_blocks_x, num_blocks_y, 1); @@ -613,9 +626,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, - columnwise_option, pow2_scale);) // kAligned - ) // OutputType - ) // InputType + columnwise_option, pow2_scale, noop_ptr);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/transpose/rtc/swap_first_dims.cu b/transformer_engine/common/transpose/rtc/swap_first_dims.cu new file mode 100644 index 000000000..89a07697a --- /dev/null +++ b/transformer_engine/common/transpose/rtc/swap_first_dims.cu @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
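The new swap_first_dims RTC kernel below exchanges the first two dimensions of a contiguous tensor: each output element at (i0, i1, i2) of a [dim0][dim1][dim2] layout is read from the input at (i1, i0, i2). A host-side reference of the same index math, handy for unit-testing the kernel:

#include <cstddef>
#include <vector>

// Reference: input is contiguous with shape [dim1][dim0][dim2]; output gets
// shape [dim0][dim1][dim2]. Mirrors the kernel's flat-index arithmetic.
template <typename T>
void swap_first_dims_ref(const std::vector<T>& in, std::vector<T>& out,
                         size_t dim0, size_t dim1, size_t dim2) {
  for (size_t idx = 0; idx < dim0 * dim1 * dim2; ++idx) {
    const size_t i2 = idx % dim2;
    const size_t i1 = (idx / dim2) % dim1;
    const size_t i0 = (idx / dim2) / dim1;
    out[idx] = in[i1 * dim0 * dim2 + i0 * dim2 + i2];
  }
}

The vectorized RTC variant applies the same mapping after reinterpreting the innermost extent in units of the chosen vector size, which is why the launcher requires dim2 (in bytes) to be divisible by that size.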
+ ************************************************************************/ + +#include "utils.cuh" + +using namespace transformer_engine; + +namespace { + +// Parameters +using VectorType = BytesToType<__VECTOR_SIZE__>::Type; +constexpr size_t block_size = __BLOCK_SIZE__; + +} // namespace + +__global__ void __launch_bounds__(block_size) + swap_first_dims_kernel(const VectorType* __restrict__ const input, + VectorType* __restrict__ const output, const size_t dim0, + const size_t dim1, const size_t dim2) { + const size_t gid = threadIdx.x + blockIdx.x * block_size; +#if __SINGLE_LOAD_STORE__ + const auto idx = gid; +#else + const size_t nthreads = gridDim.x * block_size; + for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads) +#endif // __SINGLE_LOAD_STORE__ + { + const auto idx2 = idx % dim2; + const auto idx1 = (idx / dim2) % dim1; + const auto idx0 = (idx / dim2) / dim1; + const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2; + output[idx] = input[in_offset]; + } +} diff --git a/transformer_engine/common/transpose/swap_first_dims.cu b/transformer_engine/common/transpose/swap_first_dims.cu new file mode 100644 index 000000000..08249a823 --- /dev/null +++ b/transformer_engine/common/transpose/swap_first_dims.cu @@ -0,0 +1,209 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" +#include "../util/rtc.h" +#include "../util/string.h" + +namespace transformer_engine { + +namespace { + +// String with RTC kernel implementation +#include "string_code_transpose_rtc_swap_first_dims_cu.h" + +// Hard-coded kernel parameters +constexpr size_t block_size = 128; + +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + /* Vector load/store size */ + size_t vector_size; + + /* Whether config is valid */ + bool valid = false; + /* Number of CUDA blocks */ + size_t num_blocks = 0; + /* Whether each thread needs to make exactly one load/store */ + bool single_load_store = true; + + /* Number of active SMs */ + size_t active_sm_count = 0; + /* Used bytes per L1 cache load */ + size_t bytes_per_load = 0; + /* Used bytes per L1 cache store */ + size_t bytes_per_store = 0; + + KernelConfig(size_t dim0, size_t dim1, size_t dim2, size_t sm_count, size_t vector_size_) + : vector_size{vector_size_} { + // Check that tiles are correctly aligned + if (dim2 % vector_size_ != 0) { + return; + } + valid = true; + + // Number of CUDA blocks + num_blocks = DIVUP(dim0 * dim1 * dim2 / vector_size, block_size); + if (num_blocks > 2147483647ull) { + // Maximum number of CUDA blocks + single_load_store = false; + num_blocks = 2147483647ull; + } else if (num_blocks * block_size != dim0 * dim1 * dim2 / vector_size) { + single_load_store = false; + } + + // SM occupancy + constexpr size_t warp_size = 32; + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * block_size / warp_size, warps_per_sm), sm_count); + + // L1 cache efficiency + constexpr size_t cache_line_size = 128; + bytes_per_store = std::min(cache_line_size, warp_size * vector_size); // Contiguous writes + bytes_per_load = bytes_per_store; + if (dim2 % (vector_size * warp_size) 
!= 0) { + // Some warps are reading from two non-contiguous regions + bytes_per_load /= 2; + } + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/bytes_per_load + 1/bytes_per_store) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->bytes_per_load; + const auto &s1 = this->bytes_per_store; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.bytes_per_load; + const auto &s2 = other.bytes_per_store; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * s1 * p1 * l2 * s2 * p2; + const auto cost1 = (scale / l1 + scale / s1) / p1; + const auto cost2 = (scale / l2 + scale / s2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } + } +}; + +template +__global__ void __launch_bounds__(block_size) + swap_first_dims_untuned_kernel(const Type *__restrict__ input, Type *__restrict__ output, + const size_t dim0, const size_t dim1, const size_t dim2) { + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads) { + const auto idx2 = idx % dim2; + const auto idx1 = (idx / dim2) % dim1; + const auto idx0 = (idx / dim2) / dim1; + const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2; + output[idx] = input[in_offset]; + } +} + +} // namespace + +void swap_first_dims(const Tensor &input, Tensor &output, cudaStream_t stream) { + // Check tensors + NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be simple tensor, but scaling mode is ", + to_string(input.scaling_mode), "."); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be simple tensor, but scaling mode is ", + to_string(output.scaling_mode), "."); + NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()), + ") and output tensor (dtype=", to_string(output.dtype()), ") do not match."); + NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); + NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); + + // Check tensor dimensions + const auto input_shape = input.shape(); + const auto output_shape = output.shape(); + NVTE_CHECK(input_shape.size() >= 2, "Invalid input tensor dimensions (shape=", input_shape, ")."); + NVTE_CHECK(output_shape.size() == input_shape.size(), "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + NVTE_CHECK(input_shape[0] == output_shape[1], "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + NVTE_CHECK(input_shape[1] == output_shape[0], "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + for (size_t i = 2; i < input_shape.size(); ++i) { + NVTE_CHECK(input_shape[i] == output_shape[i], "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + } + + // Reinterpret tensors as 3D tensors of bytes + const size_t dim0 = output_shape[0]; + const size_t dim1 = output_shape[1]; + size_t dim2 = 1; + for (size_t i = 2; i < output_shape.size(); ++i) { + dim2 *= output_shape[i]; + } + dim2 = get_buffer_size_bytes(dim2, output.dtype()); + + // Choose kernel config with performance heuristics + const size_t sm_count = static_cast(cuda::sm_count()); + KernelConfig config(dim0, dim1, dim2, sm_count, 
1); + if (rtc::is_enabled()) { + auto try_config = [&](size_t vector_size) { + KernelConfig new_config(dim0, dim1, dim2, sm_count, vector_size); + if (new_config < config) { + config = new_config; + } + }; + try_config(16); + try_config(8); + try_config(4); + try_config(2); + } + const size_t vector_size = config.vector_size; + + // Launch kernel + if (vector_size == 1) { + // General kernel + swap_first_dims_untuned_kernel<<>>( + static_cast(input.data.dptr), static_cast(output.data.dptr), + dim0, dim1, dim2); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else { + // Compile NVRTC kernel if needed + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = + concat_strings("swap_first_dims,vector_size=", vector_size, + ",single_load_store=", config.single_load_store); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_swap_first_dims_cu; + code = regex_replace(code, "__VECTOR_SIZE__", vector_size); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + code = + regex_replace(code, "__SINGLE_LOAD_STORE__", static_cast(config.single_load_store)); + rtc_manager.compile(kernel_label, "swap_first_dims_kernel", code, + "transformer_engine/common/transpose/rtc/swap_first_dims.cu"); + } + + // Launch NVRTC kernel + rtc_manager.launch(kernel_label, config.num_blocks, block_size, 0, stream, input.data.dptr, + output.data.dptr, dim0, dim1, dim2 / vector_size); + } +} + +} // namespace transformer_engine + +void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_swap_first_dims); + using namespace transformer_engine; + swap_first_dims(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), stream); +} diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 103f45cf1..9f0acd807 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -279,6 +279,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); }); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index af859471e..6d656f575 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -422,6 +422,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reinterpret_cast(dbias->data.dptr), reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, reduce_dbias_num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias, @@ -479,20 +480,24 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor if (full_tile) { #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(transpose_dbias_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(transpose_dbias_kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100)); #endif //#ifndef __HIP_PLATFORM_AMD__ transpose_dbias_kernel <<>>( param, row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { #ifndef __HIP_PLATFORM_AMD__ - cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, - 
cudaFuncAttributePreferredSharedMemoryCarveout, 100); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); #endif //#ifndef __HIP_PLATFORM_AMD__ transpose_dbias_kernel_notaligned <<>>( param, row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index a5a23c1c0..dcb3aa42d 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -34,15 +34,9 @@ namespace transformer_engine { -template -__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { - return DIVUP(static_cast(N), static_cast(M)) * M; -} - namespace gated_kernels { #ifndef __HIP_PLATFORM_AMD__ -constexpr size_t ALIGNMENT_SIZE = 128; constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; constexpr size_t THREADS_PER_CHUNK = 512; @@ -72,30 +66,31 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float *const scale_ptr, const size_t rows, const size_t cols) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; float amax = 0; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
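The comment above refers to the manual alignment that follows: the dynamically allocated shared-memory base is not guaranteed to satisfy the alignment TMA bulk copies need, so the pointer is rounded up to TMA_SHMEM_ALIGNMENT by hand, and buffer sizes are rounded up with DIVUP_TO_MULTIPLE so each staged buffer also starts on an aligned boundary. A small sketch of both helpers, assuming a 128-byte alignment (the value the removed ALIGNMENT_SIZE constant used); the helper names here are illustrative:

#include <cstddef>
#include <cstdint>

constexpr size_t kTmaShmemAlignment = 128;  // assumed value of TMA_SHMEM_ALIGNMENT

// Round a byte count up to the next multiple of m.
__host__ __device__ constexpr size_t divup_to_multiple(size_t n, size_t m) {
  return ((n + m - 1) / m) * m;
}

// Round a pointer value up to a power-of-two alignment.
__host__ __device__ inline uintptr_t align_up(uintptr_t p, uintptr_t alignment) {
  return (p + alignment - 1) & ~(alignment - 1);
}

__global__ void tma_buffer_setup_sketch() {
  extern __shared__ char dynamic_shmem[];
  // Align the base pointer before carving out TMA staging buffers from it.
  char* aligned = reinterpret_cast<char*>(
      align_up(reinterpret_cast<uintptr_t>(dynamic_shmem), kTmaShmemAlignment));
  (void)aligned;
}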
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; @@ -104,8 +99,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t out_act_mem = buff_size_aligned_out; - - // const size_t in_transaction_size = grad_mem + in_mem; constexpr size_t in_transaction_size = buff_elems * sizeof(IType); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned @@ -146,12 +139,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int it = 0; it < ITERATIONS; ++it) { - const int buff = it % BUFFERS_NUM; - const int next_it = it + 1; + const size_t buff = it % BUFFERS_NUM; + const size_t next_it = it + 1; if (next_it < ITERATIONS) { - const int next_buff = next_it % BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t next_buff = next_it % BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; if constexpr (IS_DGATED) { copy_2d_to_sharedx3( &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, @@ -179,10 +172,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); @@ -225,8 +218,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; // dGeLU ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, @@ -277,9 +270,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +namespace mxfp8_kernel { + +constexpr size_t CHUNK_DIM_Y = 64; +constexpr size_t CHUNK_DIM_X = 64; +constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; +constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = CHUNK_DIM_X; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t 
BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + template + bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, const __grid_constant__ CUtensorMap tensor_map_input_act, @@ -292,43 +310,73 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING; + constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); + + // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension. + constexpr size_t COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; + const size_t tid_X_colwise = threadIdx.x % CHUNK_DIM_X; - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; - const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise; + const size_t row_base_colwise = block_offset_Y + 
thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const size_t gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; + const size_t gate_scale_idx_offset_colwise = cols; - const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; + __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X]; - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); @@ -337,12 +385,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t in_mem = in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); const size_t out_mem = out_act_mem + out_gate_mem; - // const size_t in_transaction_size = grad_mem + in_mem; - const size_t in_transaction_size = (IS_DGATED ? 
3 : 2) * buff_elems * sizeof(IType); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_grad_sh = reinterpret_cast(dshmem); IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); @@ -354,379 +399,502 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_act_colwise_sh = out_act_rowwise_sh; OType *out_gate_colwise_sh = out_gate_rowwise_sh; - if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + if constexpr (ROWWISE_SCALING && COLWISE_SCALING) { out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); out_gate_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); } - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act_rowwise = - reinterpret_cast(&tensor_map_output_act_rowwise); - const uint64_t *TMAP_output_gate_rowwise = - reinterpret_cast(&tensor_map_output_gate_rowwise); - const uint64_t *TMAP_output_act_colwise = - reinterpret_cast(&tensor_map_output_act_colwise); - const uint64_t *TMAP_output_gate_colwise = - reinterpret_cast(&tensor_map_output_gate_colwise); + IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations + IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values - __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); // Initialize shared memory barrier with the number of threads participating in the barrier. #pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + __shared__ alignas(8) uint64_t mbar[STAGES]; - const bool is_master_thread = (threadIdx.x == 0); - - if (is_master_thread) { -// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); - } - ptx::fence_proxy_async_shared_cta(); - } - // Syncthreads so initialized barrier is visible to all threads. - __syncthreads(); + initialize_barriers(mbar, is_master_thread); int parity = 0; - // Prefetch data of the first stage - if (is_master_thread) { - // Initiate bulk tensor copy - // Grad - if constexpr (IS_DGATED) { - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), - TMAP_grad_in, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - } - - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), - TMAP_in_act, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Arrive on the barrier and tell how many bytes are expected to come in. 
- ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, + &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[0]); + copy_2d_to_sharedx2(&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } #pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const int buff = it % BUFFERS_NUM; - const int next_it = it + 1; - const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; - if (next_it < ITERATIONS) { - if (is_master_thread) { - const int next_buff = next_it % BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - // Initiate bulk tensor copy - if constexpr (IS_DGATED) { - // Grad - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - } - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, + global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, + global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset], + &tensor_map_input_gate, global_offset_X, global_offset_Y, + shmem_buff_size, &mbar[next_stage], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[next_it]); + copy_2d_to_sharedx2(&in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X, + global_offset_Y, &in_gate_sh[next_buff_offset], &tensor_map_input_gate, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } } ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); + ptx::mbarrier_wait_parity(&mbar[stage], parity); - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; - OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; - OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; - OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; - - // Assuming one iteration covers exactly 32 rows - const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; - const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; - - float after_dact_reg[BUFFER_STAGES_NUM]; - float after_dgate_reg[BUFFER_STAGES_NUM]; - float thread_Y_mx_block_amax = 0.0f; - float thread_Y_mx_block_amax_gate = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = + buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise; + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + float after_act_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; + float after_gate_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; +// 1. Read/Compute elements. 
Find MXFP8-block AMAX #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const size_t shmem_offset_colwise = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - const float x = act_elt; - float act_x; - float dact_x; + float act_elt = static_cast(in_act_sh[shmem_offset_colwise]); + float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); + float after_act_elt; + float after_gate_elt; - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + after_act_elt = ActOP(act_elt, {}) * gate_elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } } - after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; - } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; - } - if constexpr (USE_ROWWISE_SCALING) { + after_act_colwise[i] = after_act_elt; if constexpr (IS_DGATED) { - // dgate - float amax = fabsf(after_dgate_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); - - out_gate_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dgate_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; - } + after_gate_colwise[i] = after_gate_elt; } - float amax = fabsf(after_dact_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = 
exp2f_rcp(biased_exponent_X); - - out_act_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dact_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); + if constexpr (IS_DGATED) { + cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); + } } - } - if constexpr (USE_COLWISE_SCALING) { - __builtin_assume(thread_Y_mx_block_amax >= 0); - __builtin_assume(thread_Y_mx_block_amax_gate >= 0); - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); - if constexpr (IS_DGATED) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } } } - } - - if constexpr (USE_COLWISE_SCALING) { - const bool row_out_of_bounds = (row_base >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - if constexpr (IS_DGATED) { - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + if constexpr (ONLY_COLWISE_SCALING) { + // Threads, whose id along Y-dim is 0, don't need to store to shared memory, + // as they manage the columwise reduction of the amax + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act; } __syncthreads(); - if (tid_Y == 0) { + if (tid_Y_colwise == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_act >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_act = fmaxf(thread_amax_act, other_thread_amax); } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act; } __syncthreads(); - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + // All threads read the reduced amax (ACT) + thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; + + if constexpr (IS_DGATED) { + // Make sure the previous read of the ACT values has been completed, + // so the data are not rewritten + __syncthreads(); + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); + if (tid_Y_colwise == 0) { +#pragma unroll + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = 
subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_gate >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax); + } + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + // All threads read the reduced amax (GATE) + thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise]; } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; + const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx] = biased_exponent_act; + } + + float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + float block_scale_inverse_gate; + + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2; + const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx_gate] = biased_exponent_gate; } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + } +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_gate_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dgate_reg[stage]); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const size_t shmem_offset_elt = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; + if constexpr (IS_DGATED) { + OType2 out_pair; + ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; + const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, + block_scale_inverse_gate}; + ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair); + out_act_colwise_sh[shmem_offset_elt] = out_pair.x; + out_gate_colwise_sh[shmem_offset_elt] = out_pair.y; + } else { + const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i]; + out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); } } - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; - } - __syncthreads(); - if (tid_Y == 0) { + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + + Vec in_cached_act[WAVES]; + Vec in_cached_gate[WAVES]; + + float after_act_rowwise[SCALE_DIM_X]; + float after_gate_rowwise[SCALE_DIM_X]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x_act = {static_cast(0.0f), static_cast(0.0f)}; + IType2 thread_amax_2x_gate = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); + } + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e])); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e])); + } + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e], + in_cached_act[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act); + if constexpr (IS_DGATED) { + const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e], + in_cached_gate[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate); + } + } + } + } } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax - } - __syncthreads(); + if constexpr (!std::is_same_v) { + thread_amax_act = static_cast( + __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); + if constexpr (IS_DGATED) { + thread_amax_gate = static_cast( + __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); + } + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + Vec in_grad; + Vec in_act; + Vec in_gate; - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + in_act.load_from(&in_act_sh[shmem_offset_rowwise]); + in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); + } + +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + + float act_elt = static_cast(in_act.data.elt[e]); + float gate_elt = static_cast(in_gate.data.elt[e]); + float after_act_elt; + float after_gate_elt; + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad.data.elt[e]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; + after_act_rowwise[j] = after_act_elt; + after_gate_rowwise[j] = after_gate_elt; + } else { + after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_rowwise[j] = after_act_elt; + } + + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } + } + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + 
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } + } + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const size_t stage_scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; + const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx] = biased_exponent_act; } - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act, + block_scale_inverse_act}; + + float block_scale_inverse_gate; + ptx::floatx2 block_scale_inverse_2x_gate; + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate}; } +// 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out_act; + Vec out_gate; #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_act_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dact_reg[stage]); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in_act; + OType2 &out_act_pair = reinterpret_cast(out_act.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_act.x = in_cached_act[w].data.elt[2 * e]; + in_act.y = in_cached_act[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_act.x = after_act_rowwise[j]; + in_act.y = after_act_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act); + + if constexpr (IS_DGATED) { + IType2 in_gate; + OType2 &out_gate_pair = reinterpret_cast(out_gate.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_gate.x = in_cached_gate[w].data.elt[2 * e]; + in_gate.y = in_cached_gate[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_gate.x = after_gate_rowwise[j]; + in_gate.y = after_gate_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); + } } - } // endif USE_COLWISE_SCALING + } - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); // After syncthreads, writes by all threads are visible to TMA engine. 
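// --- Editor's note: simplified host-side model, not part of the patch. ---
// Steps 2-3 above derive a per-block E8M0 scale from the block amax
// (float_to_e8m0(amax * Quantized_Limits::max_norm_rcp)) and multiply each element by
// the reciprocal scale (exp2f_rcp) before the FP8 cast. The sketch below models that
// flow under stated assumptions: the FP8 max-normal value is taken as 448 (E4M3),
// e8m0_t as an 8-bit exponent with bias 127, and rounding as round-up to the next
// power of two; the device helpers may differ in rounding and special-case handling.
#include <algorithm>
#include <cmath>
#include <cstdint>

using sketch_e8m0_t = std::uint8_t;          // stand-in for e8m0_t
constexpr float kSketchFp8MaxNorm = 448.0f;  // assumed FP8 E4M3 max-normal value

// Biased power-of-two exponent covering the block: decoded scale = 2^(e - 127).
inline sketch_e8m0_t sketch_block_amax_to_e8m0(float block_amax) {
  const float ratio = block_amax / kSketchFp8MaxNorm;  // amax * max_norm_rcp
  const int unbiased = (ratio > 0.0f) ? static_cast<int>(std::ceil(std::log2(ratio))) : 0;
  return static_cast<sketch_e8m0_t>(std::min(std::max(unbiased + 127, 0), 254));
}

// Reciprocal of the decoded scale, playing the role of exp2f_rcp(biased_exponent).
inline float sketch_e8m0_scale_inverse(sketch_e8m0_t biased_exponent) {
  return std::exp2f(-(static_cast<int>(biased_exponent) - 127));
}

// Usage: with block_amax = 896 the biased exponent is 128 (decoded scale 2), and
// 896 * sketch_e8m0_scale_inverse(128) == 448, which fits the assumed FP8 E4M3 range.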
// Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; - // dGeLU - if constexpr (USE_ROWWISE_SCALING) { + if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_rowwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_rowwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_rowwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_rowwise_sh[buff_offset])); } } - - // dGeLU - if constexpr (USE_COLWISE_SCALING) { + if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_colwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_colwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_colwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_colwise_sh[buff_offset])); } } // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); } } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. - if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } + parity ^= 1; + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace mxfp8_kernel template void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, cudaStream_t stream) { + checkCuDriverContext(stream); + if (output->has_data()) { NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } @@ -779,28 +947,28 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem); // + mbar_mem; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols);); // NOLINT(*) - ); // NOLINT(*) + cols); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } #endif //#ifdef __HIP_PLATFORM_AMD__ @@ -808,6 +976,8 @@ template void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, cudaStream_t stream) { + checkCuDriverContext(stream); + const bool USE_ROWWISE_SCALING = output->has_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data(); @@ -818,16 +988,48 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } - // TODO: Make more general - const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; - const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; +#ifndef __HIP_PLATFORM_AMD__ + ScalingType scaling_type; + if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { + scaling_type = ScalingType::COLWISE; + } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + scaling_type = ScalingType::BIDIMENSIONAL; + } +#endif const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; + + constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; + constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; + constexpr size_t BUFFS_NUM = BUFFERS_NUM; + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); +#else + + constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; + constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; + constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + + const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + + constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; + constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; + const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) + ? THREADS_PER_CHUNK_COLWISE + : THREADS_PER_CHUNK_NON_COLWISE; +#endif + + const dim3 grid(blocks_X, blocks_Y); + const dim3 block_size(THREADS_PER_CHUNK); size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? 
output->columnwise_scale_inv.shape[1] : 1; @@ -837,116 +1039,166 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out e8m0_t *const scales_colwise_ptr = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; - const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); - const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; - OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; - OType *tensor_map_output_gate_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) + cols : nullptr; - OType *tensor_map_output_act_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; - OType *tensor_map_output_gate_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; + const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); + const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; + OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; + OType *tensor_map_output_gate_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) + cols : nullptr; + OType *tensor_map_output_act_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; + OType *tensor_map_output_gate_colwise = USE_COLWISE_SCALING ? 
reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; #else // #ifdef __HIP_PLATFORM_AMD__ - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_act_colwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, - typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, - typeToNumBits(gated_input.dtype())); - - if (USE_ROWWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, - typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - } - - if (USE_COLWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - cols, typeToNumBits(output->dtype())); - } -#endif // #ifdef __HIP_PLATFORM_AMD__ + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; - const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + } + + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, 0, input_type_bit_size); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, cols, input_type_bit_size); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - size_t out_mem = out_act_mem + out_gate_mem; - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, + cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } +#endif // #ifdef __HIP_PLATFORM_AMD__ - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; + const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; + const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - (const void*)cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - cast_mxfp8_gated_kernel - <<>>( + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_ROWWISE_SCALING ? 
32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) +#else + switch (scaling_type) { + case ScalingType::ROWWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise);); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) -#ifdef __HIP_PLATFORM_AMD__ - ); // NOLINT(*) + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::COLWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::BIDIMENSIONAL: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } #endif + ); // NOLINT(*) + ); // NOLINT(*) } template @@ -1026,9 +1278,6 @@ template void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, cudaStream_t stream) { -#ifndef __HIP_PLATFORM_AMD__ - checkCuDriverContext(stream); -#endif constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); CheckOutputTensor(*output, "output", allow_empty); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 468d31690..b7c4cf837 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. @@ -36,36 +36,25 @@ namespace transformer_engine { #ifndef __HIP_PLATFORM_AMD__ -constexpr size_t MXFP8_CHUNK_DIM_Y = 64; -constexpr size_t MXFP8_CHUNK_DIM_X = 64; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; -constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; -constexpr size_t MXFP8_BUFFERS_NUM = 2; -constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); - -constexpr size_t ELEMS_PER_THREAD = 16; -constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported -constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 -constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 - -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = - MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 -constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = - MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_BUFF_STAGES_NUM = - MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 -constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 -static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); +namespace mxfp8_kernel { + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 template -__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING, + bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK> +__global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_output_rowwise, @@ -75,201 +64,343 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - if (noop != nullptr && noop[0] == 1.0f) return; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool 
IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
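+  // A minimal sketch of the round-up idiom used on the next line, assuming
+  // TMA_SHMEM_ALIGNMENT is a power of two (e.g. 128); the addresses are illustrative only,
+  // not values taken from this change:
+  //   uintptr_t p       = 0x2050;
+  //   uintptr_t aligned = (p + 128 - 1) & ~uintptr_t{128 - 1};  // == 0x2080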
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = - SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = - SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = - SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_X = - SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 - - const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; - const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; - const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; - const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; - const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; - - const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; - const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; - // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; - const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - - const int thread_offset_Y = tid_rowwise_Y; - const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - // const int thread_offset_X_colwise = tid_colwise_X; - - const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; - const int dbias_rowwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; - const int dbias_colwise_offset_Y = blockIdx.y; - const int dbias_colwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; - const int dbias_stride = cols; - - Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; - float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { -#pragma unroll - for (int i = 
0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_rowwise[i].clear(); - } - } else { #pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_colwise[i] = 0; - } + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; } } - // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - - constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float block_amax = 0; + float block_amax = 0.0f; // Initialize shared memory barrier with the number of threads participating in the barrier. #pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + __shared__ alignas(8) uint64_t mbar[STAGES]; - initialize_barriers(mbar, is_master_thread); + initialize_barriers(mbar, is_master_thread); int parity = 0; -#pragma unroll - for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { - const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; - const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; - const int scales_rowwise_chunk_offset_Y = - scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = - scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = - scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = - scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; - const int chunk_stage_offset_X = chunk_offset_X; + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); } } + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); #pragma unroll - for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { - const int buff = iter % MXFP8_BUFFERS_NUM; - const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - - if (next_iter < MXFP8_ITERATIONS) { - const int next_buff = next_iter % MXFP8_BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, - &mbar[next_iter], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); } - } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } - ptx::fence_proxy_async_shared_cta(); + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec act_in; - Vec out_c; + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; - const int iteration_scale_rowwise_offset_Y = - scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X_rowwise; + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } - in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - } + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; - float thread_amax = 0; - float in_compute[ELEMS_PER_THREAD]; + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; - float elt = static_cast(in.data.elt[j]); + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); if constexpr (IS_ACT) { elt = OP(elt, {}); } if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[j]); + float act_in_elt = static_cast(act_in.data.elt[e]); elt *= OP(act_in_elt, {}); } - if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - if (!out_of_bounds) { - partial_dbias_rowwise[chunk_X].data.elt[j] += elt; - } - } - in_compute[j] = elt; - if constexpr (IS_ACT || IS_DACT) { + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); } @@ -277,196 +408,141 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) // If no activation, elt is 0 so we can safely do this thread_amax = fmaxf(thread_amax, fabsf(elt)); } + in_compute_rowwise[j] = elt; } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); - const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - - // Only single thread writes the computed scaling factor - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - const int scale_idx = - global_scales_offset_Y * 
scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = exp2f_rcp(biased_exponent); - -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); - } - out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); } } - if constexpr (USE_COLWISE_SCALING) { - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - float in_compute[SCALE_DIM_Y]; + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const size_t stage_scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; - float amax = 0; -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const size_t row = row_base + i; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - float elt = static_cast(in_sh[buff][i][tid_colwise_X]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if (!out_of_bounds) { - partial_dbias_colwise[chunk_X] += elt; - } - } - in_compute[i] = elt; - if constexpr (IS_ACT || IS_DACT) { - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); + } + } - __builtin_assume(block_amax >= 0); - __builtin_assume(amax >= 0); - block_amax = fmaxf(block_amax, amax); - - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + // Wait for shared memory writes to be visible to TMA engine. 
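+      // Sketch of the shmem -> TMA visibility pattern assumed below: the proxy fence orders
+      // this thread's shared-memory writes w.r.t. the async proxy, and the CTA-wide barrier
+      // extends that ordering to all threads before the master thread issues the bulk store:
+      //   ptx::fence_proxy_async_shared_cta();
+      //   __syncthreads();
+      //   if (is_master_thread) { /* issue cp.async.bulk.tensor stores from out_*_sh */ }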
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. - const float block_scale_inverse = exp2f_rcp(biased_exponent); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - out_colwise_sh[buff][i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); - } + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (USE_ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); - } - if constexpr (USE_COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); - } - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - parity ^= 1; + // Create a "bulk async-group" out of the previous bulk copy operation. 
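+        // Illustrative note (sketch, not part of this change): committing groups the bulk
+        // stores issued above so that a later
+        //   ptx::cp_async_bulk_wait_group_read<1>();
+        // (as done at the top of the next stage) can block until the buffer being reused
+        // has been fully read by the TMA engine.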
+ ptx::cp_async_bulk_commit_group(); + } } - if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; - constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; - constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; - - if (tid_rowwise_Y > 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - partial_dbias_rowwise[c].store_to( - &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); - } - } - __syncthreads(); + parity ^= 1; - if (tid_rowwise_Y == 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - Vec other_row_dbias; - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - const int left_bound = dbias_rowwise_offset_X; - const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + const size_t shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int i = 0; i < Y; ++i) { - other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; - } - } - - // Vectorized store when all elements are inside the boundaries - if (right_bound < cols) { - partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); - } else if (left_bound < cols && right_bound >= cols) { - // Element-by-element store when some elements cross the boundaries - const int in_bound_elts_count = cols - left_bound; - partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, - in_bound_elts_count); - } + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const size_t shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } - } else { + __syncthreads(); #pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; - } + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const size_t scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } + const size_t dbias_stride = cols; + const size_t dbias_offset_Y = blockIdx.y; + const size_t 
dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } } if (amax_ptr != nullptr) { const int warp_id = threadIdx.x / THREADS_PER_WARP; // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); + block_amax = reduce_max(block_amax, warp_id); } if (is_master_thread && amax_ptr != nullptr) { atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, is_master_thread); + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace mxfp8_kernel constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_X = 128; @@ -495,19 +571,19 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const size_t cols) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; - const int dbias_offset_Y = blockIdx.y + tid_Y; - const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const size_t dbias_offset_Y = blockIdx.y + tid_Y; + const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; const bool col_out_of_bounds = my_column >= cols; - const int dbias_stride = cols; + const size_t dbias_stride = cols; float partial_dbias = 0.f; @@ -515,11 +591,14 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; const bool is_master_thread = (threadIdx.x == 0); @@ -531,13 +610,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) int parity = 0; - const int chunk_offset_Y = block_offset_Y; - const int chunk_offset_X = block_offset_X; + const size_t chunk_offset_Y = block_offset_Y; + const size_t chunk_offset_X = block_offset_X; #pragma unroll for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const int chunk_stage_offset_X = chunk_offset_X; + const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const size_t chunk_stage_offset_X = chunk_offset_X; if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, @@ -552,13 +631,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) #pragma unroll for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const int buff = iter % FP8_BUFFERS_NUM; - const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t buff = iter % FP8_BUFFERS_NUM; + const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; if (next_iter < FP8_ITERATIONS) { - const int next_buff = next_iter % FP8_BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t next_buff = next_iter % FP8_BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, @@ -575,9 +654,9 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) #pragma unroll for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; + const size_t stage_offset_Y = stage; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; const size_t row = row_base + shmem_offset_y; const bool row_out_of_bounds = row >= rows; const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; @@ -616,8 +695,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y 
+ iter * FP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output), chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); @@ -635,8 +714,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) parity ^= 1; if constexpr (IS_DBIAS) { - const int dbias_offset_X = my_column; - const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + const size_t dbias_offset_X = my_column; + const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; if (!col_out_of_bounds) { dbias_workspace[dbias_offset] = partial_dbias; } @@ -678,7 +757,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; const IType *input = input_ptr + block_offset; OType *output = output_ptr + block_offset; @@ -686,11 +765,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; const bool is_master_thread = (threadIdx.x == 0); @@ -706,12 +785,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #pragma unroll for (int iter = 0; iter < ITERATIONS; ++iter) { - const int buff = iter % SHMEM_BUFFERS; - const int it_offset = iter * SHMEM_DIM; + const size_t buff = iter % SHMEM_BUFFERS; + const size_t it_offset = iter * SHMEM_DIM; - const int next_iter = iter + 1; - const int next_buff = next_iter % SHMEM_BUFFERS; - const int next_iter_offset = next_iter * SHMEM_DIM; + const size_t next_iter = iter + 1; + const size_t next_buff = next_iter % SHMEM_BUFFERS; + const size_t next_iter_offset = next_iter * SHMEM_DIM; if (next_iter < ITERATIONS) { copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, @@ -725,7 +804,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #pragma unroll for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { - const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; float elt = static_cast(in_sh[buff][shmem_offset]); if constexpr (IS_ACT) { elt = OP(elt, {}); @@ -779,12 +858,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; template __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows, - const int cols) { + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, + const size_t rows, 
const size_t cols) { using ComputeVec = Vec; using OutputVec = Vec; - const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; if (thread_id * nvec >= cols) { return; @@ -815,8 +894,8 @@ __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) template void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, cudaStream_t stream) { - constexpr int reduce_dbias_store_bytes = 8; // stg.64 - constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); @@ -824,6 +903,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, reduce_dbias_kernel <<>>( reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } #ifndef __HIP_PLATFORM_AMD__ @@ -856,6 +936,7 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream cast_fp8_1D_kernel<<>>( input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -919,6 +1000,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); @@ -932,12 +1014,13 @@ template has_data(); - bool use_colwise_scaling = output->has_columnwise_data(); #ifndef __HIP_PLATFORM_AMD__ + using namespace mxfp8_kernel; checkCuDriverContext(stream); -#endif // #ifndef __HIP_PLATFORM_AMD__ +#endif + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); @@ -950,16 +1033,30 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, } CheckNoopTensor(*noop, "cast_noop"); - // TODO: Make more general - const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; - const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; - const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); - const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); - const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; + constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; + constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; +#else + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 
128 : 64; + + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +#endif + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = @@ -972,6 +1069,17 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; +#ifndef __HIP_PLATFORM_AMD__ + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } +#endif + if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); @@ -986,23 +1094,22 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); - const dim3 block(MXFP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(IType))), IS_ALIGNED, - cast_mxfp8_2D_kernel<<>>( + cast_mxfp8_2D_kernel + <<>>( reinterpret_cast(input.data.dptr), (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, reinterpret_cast(output->data.dptr), @@ -1010,52 +1117,108 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, scales_rowwise_ptr, scales_colwise_ptr, reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))); // NOLINT(*) #else // #ifdef __HIP_PLATFORM_AMD__ - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, - MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(input.dtype())); - } + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); + } - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, output_type_bit_size); + } - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, - cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } - cast_mxfp8_2D_kernel<<>>( + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? 
buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + cast_mxfp8_2D_kernel + <<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::COLWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::BIDIMENSIONAL: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } #endif // #ifdef __HIP_PLATFORM_AMD__ - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - - }); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) -#ifdef __HIP_PLATFORM_AMD__ - ); // NOLINT(*) -#endif + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) } namespace detail { @@ -1152,8 +1315,8 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) { bool dimensions_supported_by_TMA(const Tensor *const t) { const size_t cols = t->flat_last_dim(); - constexpr int TMA_bytes = 16; - const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); + constexpr size_t TMA_bytes = 16; + const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); return cols % alignment_requirement == 0; } #endif //#ifndef __HIP_PLATFORM_AMD__ @@ -1171,8 +1334,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons case NVTE_DELAYED_TENSOR_SCALING: { if (!IS_DBIAS && !IS_DACT) { if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 cast_fp8_1D(input, output, stream); } else { @@ -1181,9 +1344,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons } } else if (!IS_DBIAS && IS_DACT) { if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, 
TMA_gmem_alignment) && - is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 (+dAct) cast_fp8_2D(input, act_input, output, dbias, workspace, stream); @@ -1216,7 +1379,7 @@ void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { // zhongboz: should we just ignore IS_ACT here? NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " on GPU with compute capability < 10.0."); + " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); } switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { @@ -1336,7 +1499,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor.data, stream); break; } case NVTE_BLOCK_SCALING_1D: { @@ -1364,10 +1528,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; } - quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, output_tensor->data, - output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, stream); + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor.data, stream); break; } #endif diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index d4835b611..cfc39403a 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -57,6 +57,22 @@ void *get_symbol(const char *symbol, int cuda_version) { } #endif +void ensure_context_exists() { + static thread_local bool need_check = []() { + CUcontext context; + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context); + if (context == nullptr) { + // Add primary context to context stack + CUdevice device; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device()); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device); + } + return false; + }(); +} + } // namespace cuda_driver } // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index f131bab45..32ed2ec4f 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -41,6 +41,14 @@ inline CUresult call(const char *symbol, ArgTs... args) { return (*func)(args...); } +/*! \brief Ensure that the calling thread has a CUDA context + * + * Each thread maintains a stack of CUDA contexts. 
If the calling + * thread has an empty stack, the primary context is added to the + * stack. + */ +void ensure_context_exists(); + } // namespace cuda_driver } // namespace transformer_engine diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 8f0a9730b..aaeb169b1 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -76,8 +76,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; constexpr int transaction_size = shmem_buff_size; @@ -158,7 +158,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float block_scale = ptx::exp2f(biased_exponent); if constexpr (USE_ROWWISE_SCALING) { Vec in; @@ -334,6 +334,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s #ifdef __HIP_PLATFORM_AMD__ ); // NOLINT(*) #endif + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace dequantization diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 6dbcd6974..6ab5eb958 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -6,7 +6,8 @@ * See LICENSE for license information. 
************************************************************************/ -#pragma once +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #include #ifdef __HIP_PLATFORM_AMD__ @@ -17,8 +18,13 @@ #endif // __HIP_PLATFORM_AMD__ #include +#ifdef NVTE_WITH_CUBLASMP +#include +#endif // NVTE_WITH_CUBLASMP + #include #include +#include #include "../util/string.h" @@ -102,3 +108,17 @@ NVTE_ERROR("NVRTC Error: ", nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \ } \ } while (false) + +#ifdef NVTE_WITH_CUBLASMP + +#define NVTE_CHECK_CUBLASMP(expr) \ + do { \ + const cublasMpStatus_t status = (expr); \ + if (status != CUBLASMP_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ + } \ + } while (false) + +#endif // NVTE_WITH_CUBLASMP + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index a1899d5b1..0d92b243a 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -35,6 +35,7 @@ struct MultiPaddingArgs { int padded_num_rows_list[kMaxTensorsPerKernel]; // Input matrix widths int row_length_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each // tensor int block_range[kMaxTensorsPerKernel + 1]; // Number of tensors being processed by kernel @@ -247,6 +248,7 @@ void multi_padding(const std::vector input_list, std::vector o const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_padding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args.num_tensors = 0; } @@ -276,6 +278,7 @@ void multi_padding(const std::vector input_list, std::vector o const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_padding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -321,6 +324,7 @@ void multi_unpadding(const std::vector input_list, std::vector const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_unpadding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args.num_tensors = 0; } @@ -348,6 +352,7 @@ void multi_unpadding(const std::vector input_list, std::vector const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_unpadding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 55bc247f7..7c38a337b 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -104,6 +106,56 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 
1 + : __int_as_float((254 - biased_exp) + << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) +} + +__device__ __forceinline__ float exp2f(e8m0_t biased_exp) { + return __int_as_float(biased_exp << FP32_MANTISSA_BITS); +} + +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { +#ifdef __HIP_PLATFORM_AMD__ +#define __CUDA_ARCH_HAS_FEATURE__(x) 0 +#endif +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -169,6 +221,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { asm volatile("fence.proxy.async.shared::cta;"); } +template +struct alignas(2 * sizeof(T)) FPx2 { + T x; + T y; +}; + +using floatx2 = FPx2; +using bf16x2 = FPx2; +using fp16x2 = FPx2; +using fp8e4m3x2 = FPx2; +using fp8e5m2x2 = FPx2; + +static_assert(sizeof(floatx2) == 8); +static_assert(sizeof(bf16x2) == 4); +static_assert(sizeof(fp16x2) == 4); +static_assert(sizeof(fp8e4m3x2) == 2); +static_assert(sizeof(fp8e5m2x2) == 2); + +// SIMD like "Fused" cast + multiplication (x2) +__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, + const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + "mul.f32x2 val_pair, %1, %2; \n\t" + "mov.b64 {val2,val1}, val_pair; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, + const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + "mul.f32x2 val_pair, %1, %2; \n\t" + "mov.b64 {val2,val1}, val_pair; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_bf16; \n\t" + ".reg.b16 val2_bf16; \n\t" + "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t" + "cvt.f32.bf16 val1, val1_bf16; \n\t" + "cvt.f32.bf16 val2, val2_bf16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + 
"l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_bf16; \n\t" + ".reg.b16 val2_bf16; \n\t" + "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t" + "cvt.f32.bf16 val1, val1_bf16; \n\t" + "cvt.f32.bf16 val2, val2_bf16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_fp16; \n\t" + ".reg.b16 val2_fp16; \n\t" + "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t" + "cvt.f32.f16 val1, val1_fp16; \n\t" + "cvt.f32.f16 val2, val2_fp16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) { + asm volatile( + "{\n" + ".reg.b64 val_pair_before; \n\t" + ".reg.b64 val_pair_after; \n\t" + ".reg.b32 val1; \n\t" + ".reg.b32 val2; \n\t" + ".reg.b16 val1_fp16; \n\t" + ".reg.b16 val2_fp16; \n\t" + "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t" + "cvt.f32.f16 val1, val1_fp16; \n\t" + "cvt.f32.f16 val2, val2_fp16; \n\t" + "mov.b64 val_pair_before, {val1,val2}; \n\t" + "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t" + "mov.b64 {val2,val1}, val_pair_after; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t" + "}" + : "=h"(reinterpret_cast(out)) + : "r"(reinterpret_cast(in)), + "l"(reinterpret_cast(scale))); +} + +__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) { + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;" + : "=r"(reinterpret_cast(dst)) + : "r"(reinterpret_cast(p1)), + "r"(reinterpret_cast(p2))); +} + +__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) { + asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;" + : "=r"(reinterpret_cast(dst)) + : "r"(reinterpret_cast(p1)), + "r"(reinterpret_cast(p2))); +} + #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index be06e807e..b243a8a0b 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -54,7 +54,9 @@ transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ + .value("EXTERNAL_BULK_OVERLAP_AG", \ + 
transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index b8fee6862..a53fd51c5 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -1,23 +1,23 @@ /************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once +#include #include #include -#include -#include - -#include -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" +#include "common.h" #include "math.h" -#include "../util/rocm_vectorized_2d.cuh" +#include "ptx.cuh" +#include "rocm_vectorized_2d.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/cast.h" +#include "vectorized_pointwise.h" +#include "utils.cuh" namespace transformer_engine { namespace gated_kernels { @@ -191,14 +191,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; } + // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 + if constexpr (!std::is_same_v) { + after_dact_reg[stage] = static_cast(static_cast(after_dact_reg[stage])); + if constexpr (IS_DGATED) { + after_dgate_reg[stage] = static_cast(static_cast(after_dgate_reg[stage])); + } + } + if constexpr (USE_ROWWISE_SCALING) { if constexpr (IS_DGATED) { // dgate float amax = fabsf(after_dgate_reg[stage]); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); out_gate_rowwise_sh[shmem_idx] = static_cast(scale_reciprocal_X * after_dgate_reg[stage]); @@ -217,8 +225,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float amax = fabsf(after_dact_reg[stage]); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); out_act_rowwise_sh[shmem_idx] = static_cast(scale_reciprocal_X * after_dact_reg[stage]); @@ -273,8 +281,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); // Only single thread writes the computed scaling factor // Also assuming one iteration covers exactly 32 rows @@ -319,8 +327,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * 
Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); // Only single thread writes the computed scaling factor // Also assuming one iteration covers exactly 32 rows diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index d62350e0a..52a77733a 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -1,26 +1,32 @@ /************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once +#include #include #include -#include - -#include -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" +#include "common.h" #include "math.h" -#include "transformer_engine/transformer_engine.h" -#include "../util/rocm_vectorized_2d.cuh" +#include "ptx.cuh" +#include "rocm_vectorized_2d.cuh" +#include "transformer_engine/cast.h" +#include "transpose/cast_transpose.h" +#include "vectorized_pointwise.h" +#include "utils.cuh" namespace transformer_engine { +// Forward declaration, definition is in cast_kernels.cuh +template +void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream); + + constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; @@ -209,6 +215,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) partial_dbias_rowwise[chunk_X].data.elt[j] += elt; } } + // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } in_compute[j] = elt; if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); @@ -221,7 +231,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation + ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 
1 : 0); // Normalization requires a +1 scale to avoid saturation // Only single thread writes the computed scaling factor if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { @@ -234,7 +244,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exponent; } - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { @@ -268,6 +278,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) partial_dbias_colwise[chunk_X] += elt; } } + // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } in_compute[i] = elt; if (!out_of_bounds) { amax = fmaxf(amax, fabsf(elt)); @@ -278,7 +292,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __builtin_assume(amax >= 0); block_amax = fmaxf(block_amax, amax); - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation + const e8m0_t biased_exponent = ptx::float_to_e8m0(amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; @@ -286,7 +300,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; scales_colwise[scale_idx] = biased_exponent; - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll for (int i = 0; i < SCALE_DIM_Y; i++) { out_colwise_sh[i][tid_colwise_X] = @@ -540,4 +554,5 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso } } + } // namespace transformer_engine diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index ae5cb4bbd..398e4c0ad 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -1,26 +1,26 @@ /************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ #pragma once +#include #include #include -#include - -#include #include -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" +#include "common.h" #include "math.h" +#include "ptx.cuh" +#include "rocm_vectorized_2d.cuh" #include "transformer_engine/activation.h" +#include "transformer_engine/cast.h" +#include "transpose/cast_transpose.h" #include "transformer_engine/transpose.h" -#include "../util/rocm_vectorized_2d.cuh" +#include "utils.cuh" +#include "vectorized_pointwise.h" namespace transformer_engine { @@ -102,7 +102,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float block_scale = ptx::exp2f(biased_exponent); if constexpr (USE_ROWWISE_SCALING) { Vec in; diff --git a/transformer_engine/common/util/rtc.h b/transformer_engine/common/util/rtc.h index 820b16c20..7de1e4d55 100644 --- a/transformer_engine/common/util/rtc.h +++ b/transformer_engine/common/util/rtc.h @@ -59,6 +59,7 @@ class Kernel { template void launch(int device_id, const dim3 grid_dim, const dim3 block_dim, unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) { + cuda_driver::ensure_context_exists(); void *arg_ptrs[] = {const_cast(static_cast(&args))...}; NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y, grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes, diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 420b9ed3b..0d667a0ec 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -183,6 +183,7 @@ __launch_bounds__(unary_kernel_threads) __global__ VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -196,27 +197,28 @@ __launch_bounds__(unary_kernel_threads) __global__ for (int i = 0; i < nvec; ++i) { const ComputeType val = static_cast(loader.separate()[i]); ComputeType temp = OP(val, p); - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); - + } + if constexpr (is_fp8::value) { temp = temp * s; } - storer.separate()[i] = static_cast(temp); } storer.store(tid, N); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); @@ -236,6 +238,7 @@ __launch_bounds__(unary_kernel_threads) __global__ VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -251,10 +254,11 @@ 
__launch_bounds__(unary_kernel_threads) __global__ const ComputeType val = static_cast(loader.separate()[i]); const ComputeType g = static_cast(grad_loader.separate()[i]); ComputeType temp = OP(val, p) * g; - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); - + } + if constexpr (is_fp8::value) { temp = temp * s; } @@ -262,16 +266,17 @@ __launch_bounds__(unary_kernel_threads) __global__ } storer.store(tid, N); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); @@ -359,6 +364,7 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -393,6 +399,7 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -406,6 +413,7 @@ __launch_bounds__(unary_kernel_threads) __global__ const size_t M = num_aligned_elements * m; ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -425,25 +433,28 @@ __launch_bounds__(unary_kernel_threads) __global__ const ComputeType val = static_cast(loader0.separate()[i]); const ComputeType val2 = static_cast(loader1.separate()[i]); ComputeType temp = static_cast(Activation(val, p) * val2); - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); + } + if constexpr (is_fp8::value) { temp = temp * s; } storer.separate()[i] = static_cast(static_cast(temp)); } storer.store(id_x, n); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); @@ -482,6 +493,7 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -497,6 +509,7 @@ __launch_bounds__(unary_kernel_threads) __global__ const size_t M = num_aligned_elements * m; ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -524,11 +537,13 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; ComputeType after_dgate = grad_val * Activation(gelu_in, p); - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(after_dgelu), max); - after_dgelu = after_dgelu * s; max = fmaxf(fabsf(after_dgate), max); + } + if constexpr (is_fp8::value) { + after_dgelu = 
after_dgelu * s; after_dgate = after_dgate * s; } @@ -538,16 +553,17 @@ __launch_bounds__(unary_kernel_threads) __global__ storer0.store(id_x, n); storer1.store(id_x, n); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); @@ -589,6 +605,7 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index b004de6dc..799becaee 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -984,10 +984,7 @@ using fp8e5m2 = te_hip_fp8_e5m2; #endif //__HIP_PLATFORM_AMD__ using e8m0_t = uint8_t; -constexpr uint32_t FP32_MANTISSA_BITS = 23; -constexpr uint32_t FP32_EXPONENT_BIAS = 127; - -enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 }; +enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 }; template struct Numeric_Traits; @@ -1043,47 +1040,6 @@ struct Quantized_Limits { #endif // !defined(__HIP_DEVICE_COMPILE__) }; -__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (isnan(val)) { - return 0xFF; - } - if (isinf(val)) { - return 0xFE; - } -#ifdef __HIP_PLATFORM_AMD__ -#define __CUDA_ARCH_HAS_FEATURE__(x) 0 -#endif //__HIP_PLATFORM_AMD__ -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) - uint16_t out; - asm volatile( - "{\n" - "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" - "}" - : "=h"(out) - : "f"(val)); - return *reinterpret_cast(&out); -#else - if (val == 0.0f) { - return 0x00; - } - uint32_t val_u32 = *reinterpret_cast(&val); - e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); - uint32_t mantissa = val_u32 & 0x7FFFFF; - // Round up exponent and deal with satfinite. - if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { - ++exponent; - } - return exponent; -#endif -} - -__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { - return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); -} - } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index 13ab6040d..94fc6d129 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -5,7 +5,8 @@ """API definition for nvidia-dlframework-inspect.""" import copy -from typing import Dict, Union +import warnings +from typing import Dict, Union, Tuple, Optional from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper from nvdlfw_inspect.registry import Registry @@ -101,13 +102,23 @@ def _process_transformer_engine_config(self, config, **kwargs): class TEDefaultFeatures: """Transformer Engine API calls default behavior.""" - def fp8_gemm_enabled(self, config: Dict, layer_name: str, gemm: str, iteration: int) -> bool: + def fp8_gemm_enabled( + self, + config: Dict, + layer_name: str, + gemm: str, + iteration: int, + ) -> bool | Tuple[bool, Optional[int]]: """ If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled, then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*. If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled, the result of this call does not matter. + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be disabled. + It can return (bool, None) if the feature will never be enabled for that layer and gemm. + Returning the next enabled iteration can help optimize CPU usage. + Parameters ---------- @@ -122,9 +133,9 @@ def fp8_gemm_enabled(self, config: Dict, layer_name: str, gemm: str, iteration: Returns ------- - bool - default is True + Union[bool, Tuple[bool, Optional[int]]] - default is (True, None) """ - return True # if it is false, fp8_gemm will be turned off. Otherwise nothing happens. + return True, None # if it is false, fp8_gemm will be turned off. Otherwise nothing happens. def modify_tensor_enabled( self, @@ -133,9 +144,16 @@ def modify_tensor_enabled( gemm: str, tensor_name: str, iteration: int, - ) -> bool: + ) -> bool | Tuple[bool, Optional[int]]: """ - It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name. It has **higher priority** than fp8_gemm, if *modify_tensor_enabled* returns True, then modify_tensor call is invoked for the respective tensor no matter what. + It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name. + It has **higher priority** than fp8_gemm; if *modify_tensor_enabled* returns True or (True, next_enabled_iter), + then modify_tensor call is invoked for the respective tensor no matter what. + + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor. + Returning the next enabled iteration can help optimize CPU usage, especially when the interval between modify_tensor is large. + Returning only a bool is deprecated. 
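To make the ``(bool, Optional[int])`` contract above concrete, the sketch below shows how a hypothetical feature that is active only every ``freq``-th iteration could implement such an ``*_enabled`` call. The class name and defaults are illustrative, and the Registry/``@api_method`` plumbing used by real features is omitted.

.. code-block:: python

    from typing import Optional, Tuple

    class EveryNthIteration:
        """Hypothetical feature: active only every `freq`-th iteration."""

        def __init__(self, freq: int = 10):
            self.freq = freq

        def modify_tensor_enabled(
            self, config, layer_name: str, gemm: str, tensor_name: str, iteration: int
        ) -> Tuple[bool, Optional[int]]:
            run_now = (iteration % self.freq == 0)
            # Next iteration at which the feature runs again; returning it lets
            # the framework skip this feature's checks entirely until then.
            next_iter = iteration + (self.freq - iteration % self.freq)
            return run_now, next_iter

A feature that will never be active for a given layer would instead return ``(False, None)``.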
Parameters ---------- @@ -153,9 +171,9 @@ def modify_tensor_enabled( Returns ------- - bool - default is False + Union[bool, Tuple[bool, Optional[int]]] - default is (False, None) """ - return False + return False, None def modify_tensor( self, @@ -167,7 +185,7 @@ def modify_tensor( default_quantizer: Quantizer, iteration: int, out: Union[torch.Tensor, QuantizedTensor], - ) -> Union[torch.Tensor, QuantizedTensor, None]: + ) -> torch.Tensor | QuantizedTensor | None: """ It allows tensor modification. For example, feature `FakeQuant` uses it to emulate casting to FP8. @@ -227,6 +245,9 @@ def inspect_tensor( layer_name: str, tensor_name: str, tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor], + columnwise_quantized_tensor: Optional[torch.Tensor], + quantizer: Optional[Quantizer], iteration: int, tp_group: torch.distributed.ProcessGroup, ) -> None: @@ -243,6 +264,12 @@ def inspect_tensor( one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`], tensor: torch.Tensor tensor in high precision, + rowwise_quantized_tensor: Optional[torch.Tensor] + rowwise quantized tensor, + columnwise_quantized_tensor: Optional[torch.Tensor] + columnwise quantized tensor, + quantizer: Optional[Quantizer] + quantizer, iteration: int iteration number - equal to the number of times `debug_api.step()` was called. tp_group: torch.distributed.ProcessGroup @@ -260,12 +287,15 @@ def inspect_tensor_postquantize( config: Dict, layer_name: str, tensor_name: str, - gemm: str, tensor: torch.Tensor, iteration: int, tp_group: torch.distributed.ProcessGroup, + rowwise: bool, ) -> None: """ + + This is deprecated call, we advise to use *inspect_tensor* instead. + Similar to *inspect_tensor*, but is run after one of the: fp8 cast, modify_tensor if they are run. If none of the fp8 cast or modify_tensor is invoked, then *inspect_tensor_postquantize* is also not invoked. The feature LogFp8Stats uses this call to collect FP8 statistics after the quantization. Parameters @@ -278,8 +308,6 @@ def inspect_tensor_postquantize( one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`], tensor: torch.Tensor tensor in fp8 or processed tensor after the modify_tensor call, - gemm: str - one of [`fprop`, `dgrad`, `wgrad`], iteration: int iteration number - equal to the number of times `debug_api.step()` was called. tp_group: torch.distributed.ProcessGroup @@ -298,9 +326,15 @@ def inspect_tensor_enabled( layer_name: str, tensor_name: str, iteration: int, - ) -> bool: + ) -> bool | Tuple[bool, Optional[int]]: """ - It is a routing call, which is run at the initialization of the layer. If it returns true, then *inspect_tensor* for a given GEMM and tensor will be invoked. + It is a routing call, which is run at the initialization of the layer. + Determines if *inspect_tensor* for a given GEMM and tensor will be invoked. + + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + It can return (bool, None) if the feature will never be enabled for that layer and tensor. + Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor is large. + Returning only a bool is deprecated. 
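For illustration, the extended ``inspect_tensor`` signature above lets a feature compare the high-precision tensor with its quantized counterpart. The sketch below is a hypothetical hook, not part of this change: it assumes the quantized tensor exposes a ``dequantize()`` method and omits the config handling and Registry plumbing that real features use.

.. code-block:: python

    def inspect_underflows(layer_name, tensor_name, tensor,
                           rowwise_quantized_tensor=None,
                           columnwise_quantized_tensor=None,
                           quantizer=None, **kwargs):
        """Hypothetical inspect_tensor-style hook: log quantization underflows."""
        if rowwise_quantized_tensor is None or quantizer is None:
            return  # no low-precision recipe active, nothing to compare against
        # Dequantize back to high precision and count non-zero inputs that collapsed to 0.
        dequantized = rowwise_quantized_tensor.dequantize()
        nonzero = (tensor != 0)
        num_underflows = ((dequantized == 0) & nonzero).sum().item()
        pct = 100.0 * num_underflows / max(nonzero.sum().item(), 1)
        print(f"{layer_name}/{tensor_name}: underflows% = {pct:.3f}")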
Parameters ---------- @@ -316,9 +350,9 @@ def inspect_tensor_enabled( Returns ------- - bool - default is False + Union[bool, Tuple[bool, Optional[int]]] - default is (False, None) """ - return False + return False, None def inspect_tensor_postquantize_enabled( self, @@ -327,11 +361,18 @@ def inspect_tensor_postquantize_enabled( gemm: str, tensor_name: str, iteration: int, - ) -> bool: + ) -> bool | Tuple[bool, Optional[int]]: """ + This is deprecated call, we advise to use *inspect_tensor* and *inspect_tensor_enabled* instead. + It is a routing call, which is run at the initialization of the layer. - If it returns true, then *inspect_tensor_postquantize* for - a given GEMM and tensor will be invoked. + Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked. + + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor name. + Returning the next enabled iteration can help optimize CPU usage, + especially when the interval between inspect_tensor_postquantize is large. + Returning only a bool is deprecated. Parameters ---------- @@ -349,9 +390,9 @@ def inspect_tensor_postquantize_enabled( Returns ------- - bool - default is False + Union[bool, Tuple[bool, Optional[int]]] - default is (False, None) """ - return False + return False, None @Registry.register_namespace_api(namespace="transformer_engine") @@ -371,8 +412,8 @@ def __init__(self): "modify_tensor": ["tensor_name", "gemm"], "inspect_tensor": ["tensor_name"], "inspect_tensor_postquantize": ["tensor_name"], - "inspect_tensor_enabled": ["tensor_name"], - "inspect_tensor_postquantize_enabled": ["tensor_name"], + "inspect_tensor_enabled": ["tensor_name", "iteration"], + "inspect_tensor_postquantize_enabled": ["tensor_name", "iteration"], "modify_tensor_enabled": ["tensor_name"], } @@ -420,7 +461,7 @@ def routing_condition(self, api_name, config, _, feature_obj, **kwargs): def output_assertions_hook(self, api_name, ret, **kwargs): """Output hooks used to check correctness of the outputs of the API calls.""" if "enabled" in api_name or api_name == "fp8_gemm": - assert isinstance(ret, bool) + assert isinstance(ret, (bool, tuple)) if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]: assert ret is None if api_name == "modify_tensor": @@ -432,6 +473,57 @@ def output_assertions_hook(self, api_name, ret, **kwargs): if kwargs["dtype"] is not None: assert ret.dtype == kwargs["dtype"] + def call_feature(self, call, feat_config, layer_name, **kwargs): + """ + For backward compatibility, remove kwargs that are not needed for the call + """ + if call.__name__ == "inspect_tensor": + kwargs_copy = kwargs.copy() + for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]: + if k not in call.__code__.co_varnames: + kwargs_copy.pop(k) + else: + kwargs_copy = kwargs + + if call.__name__ == "inspect_tensor_postquantize": + warnings.warn( + "inspect_tensor_postquantize is deprecated, use inspect_tensor instead.", + DeprecationWarning, + ) + + return call(feat_config, layer_name, **kwargs_copy) + + def handle_multi_feature_output( + self, api_name, multi_feature_outputs, features_to_invoke, **kwargs + ): + """ + Handle multi-tensor output of the API calls. + """ + if "enabled" in api_name: + # *_enabled feature calls can return bool, or tuple (bool, Optional[int]). 
+ # If any of them returns bool, then we return bool - this means that we cannot state anything + # about enablement in the next steps. + # If all of them return a tuple (bool, Optional[int]), we return the minimum value, + # representing the number of steps after the feature will be enabled next time. + # If the second value is None, that means that the feature will never be enabled. + all_ret_tuple = all( + isinstance(feature_output, tuple) for feature_output in multi_feature_outputs + ) + if all_ret_tuple: + run_current = any(feature_output[0] for feature_output in multi_feature_outputs) + next_iter = None + for feature_output in multi_feature_outputs: + if next_iter is None: + next_iter = feature_output[1] + elif feature_output[1] is not None: + next_iter = min(next_iter, feature_output[1]) + return run_current, next_iter + run_current = any(feature_output for feature_output in multi_feature_outputs) + return run_current, None + return super().handle_multi_feature_output( + api_name, multi_feature_outputs, features_to_invoke, **kwargs + ) + def step(self): """This function is called by the nvidia-dlframework-inspect after every debug_api.step()""" STATS_BUFFERS.log_stats() diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index b2400d1cd..ef2cccbe4 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -50,4 +50,4 @@ def fp8_gemm_enabled( # If this feature is invoked, then FP8 GEMM is disabled. # If not, then default behaviour in TransformerEngineAPI # is that fp8_gemm() API call returns True. - return False + return False, iteration + 1 diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index 7e885fe5e..c3b0e4cca 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -41,7 +41,7 @@ def fp8_gemm_enabled( # If this feature is invoked, then FP8 GEMM is disabled. # If not, then default behavior in TransformerEngineAPI # is that fp8_gemm() API call returns True. 
- return False + return False, iteration + 1 def parse_config_and_api(self, config, **_kwargs): """Determines whether to run the API diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index 4a5b6c34a..4b01f9712 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -127,14 +127,14 @@ def fp8_gemm_enabled( self, config, layer_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call responsible for selecting between high-precision and FP8 GEMM execution.""" - return False + return False, None @api_method def modify_tensor_enabled( self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run process_tensor() in the forward.""" - return True + return True, iteration + 1 @api_method def modify_tensor( diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index e5c84a9bd..31620211d 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -4,54 +4,119 @@ """LogFp8TensorStats Feature support for nvidia-dlframework-inspect""" -from typing import Dict, Union +from typing import Dict, Optional, List, Tuple +from contextlib import contextmanager import torch - import nvdlfw_inspect.api as debug_api + + from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase -from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from transformer_engine.debug.pytorch.debug_state import TEDebugState +from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter + + +ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] + + +def _get_recipe_name(quantizer: Optional[Quantizer]): + if quantizer is None: + return "" + if isinstance(quantizer, Float8Quantizer): + return "fp8_delayed_scaling" + if isinstance(quantizer, Float8CurrentScalingQuantizer): + return "fp8_current_scaling" + if isinstance(quantizer, MXFP8Quantizer): + return "mxfp8" + if isinstance(quantizer, Float8BlockQuantizer): + return "fp8_block_scaling" + raise ValueError(f"Unsupported quantizer type: {type(quantizer)}") + + +def _get_new_quantizer(recipe_name, fp8_dtype): + if recipe_name == "fp8_block_scaling": + return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + if recipe_name == "fp8_current_scaling": + return Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, device=torch.device("cuda"), 
rowwise=True, columnwise=True + ) + if recipe_name == "mxfp8": + return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + if recipe_name == "fp8_delayed_scaling": + raise ValueError("Cannot recreate quantizer for fp8_delayed_scaling") + raise ValueError(f"Unsupported recipe name: {recipe_name}") @Registry.register_feature(namespace="transformer_engine") class LogFp8TensorStats(BaseLogTensorStats): """ - This feature handles logging of FP8 tensor stats. + Logs statistics of quantized tensors. + + Supports computing statistics for the current recipe, but also + allows you to see what would happen if different recipes were used for these tensors in the current iteration. + For example, during delayed-scaling training you may wish to track + "current_scaling_underflows%" to measure the accuracy of the current scaling + factors; note that this requires an extra cast and therefore adds overhead. + Using a logging frequency (`freq`) greater than 1 is recommended in this case. + Computing the stats matching the training recipe does not require an extra cast. + Statistics are identified by the pattern `<recipe>_<stat>` with an optional `_columnwise` suffix (e.g. + `delayed_scaling_underflows%` or `mxfp8_scale_inv_min_columnwise`). + One can provide `<stat>` only, in which case the current training recipe is used. - In a distributed setting, the auxiliary stats are computed on each rank and gathered after - the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log - stats! + Stats for delayed-scaling cannot be collected if delayed-scaling is not the current training recipe. - `LogFp8TensorStats` supports micro-batching. If multiple forward/backward passes are invoked - per `debug_api.step()`, then stats for all tensors except weights will be accumulated. + In distributed runs each rank first computes its local statistics; the values + are gathered the next time `debug_api.step()` is called. Remember to call + `debug_api.step()` every training step so the logs are flushed. - `LogFp8TensorStats` can induce significant overhead. To mitigate this issue, logging stats - with `freq > 1` is recommended. If `LogFp8TensorStats` is not used in a given step, the - overhead is smaller. If no other feature is used for the layer, the TE layer will - run as fast as it would without `debug_api` initialized. + The feature is micro-batch aware: if several forward/backward passes occur + between successive `debug_api.step()` calls, statistics are accumulated for all + tensors except weights. + + Collecting FP8 statistics is expensive. Choosing a larger `freq` reduces the + overhead, and if the feature is skipped for a step the additional cost is + minimal. When no other debug feature is active, the layer runs at normal + Transformer Engine speed. Parameters ---------- stats: List[str] - list of statistics to log + Each stat is a string of the form `<recipe>_<stat>`, with an optional `_columnwise` suffix (i.e., `<recipe>_<stat>_columnwise`). + If `<recipe>` is omitted, the current training recipe is used. + For mxfp8 and fp8_block_scaling the `_columnwise` suffix can be provided; the stat is then computed on the columnwise (transpose) + version of the tensor, which can be numerically different from the rowwise (non-transpose) tensor. + The "_columnwise" suffix is not supported for fp8_delayed_scaling and fp8_current_scaling.
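As a side note on the naming convention just described, a stat string decomposes into a recipe, a stat name, and a columnwise flag roughly as in the stand-alone sketch below. The helper name is illustrative; the actual parsing is done by ``get_recipe_from_stat`` and ``check_if_stat_is_supported`` further down.

.. code-block:: python

    RECIPES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]

    def split_stat(stat, current_recipe):
        """Return (recipe, stat_name, columnwise) for e.g. 'mxfp8_scale_inv_min_columnwise'."""
        columnwise = stat.endswith("_columnwise")
        if columnwise:
            stat = stat[: -len("_columnwise")]
        for recipe in RECIPES:
            if stat.startswith(recipe + "_"):
                return recipe, stat[len(recipe) + 1:], columnwise
        # No recipe prefix: fall back to the recipe used for training.
        return current_recipe, stat, columnwise

    assert split_stat("mxfp8_scale_inv_min_columnwise", "fp8_delayed_scaling") == (
        "mxfp8", "scale_inv_min", True)
    assert split_stat("underflows%", "fp8_current_scaling") == (
        "fp8_current_scaling", "underflows%", False)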
+ + recipes: + - fp8_delayed_scaling, + - fp8_current_scaling, + - mxfp8, + - fp8_block_scaling, + + stats: + - underflows% - percentage of non-zero elements of tensor clipped to 0 after quantization, + - overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling, + - scale_inv_min - minimum of the inverse of the scaling factors, + - scale_inv_max - maximum of the inverse of the scaling factors, + - mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements, - - underflows% - percentage of elements of the tensor equal to 0, tensors/tensors_struct: List[str] list of tensors to log + - activation, + - gradient, + - weight, - - activation - - gradient - - weight freq: Optional[int], default = 1 frequency of logging stats, stats will be logged every `freq` steps start_step: Optional[int], default = None @@ -74,7 +139,7 @@ class LogFp8TensorStats(BaseLogTensorStats): enabled: True tensors_struct: - tensor: activation - stats: [underflows%] + stats: [mxfp8_underflows%] freq: 1 - tensor: gradient stats: [underflows%] @@ -83,42 +148,147 @@ class LogFp8TensorStats(BaseLogTensorStats): end_step: 80 """ - def _get_supported_stats_list(self): - """Returns stats this feature can log.""" - return {"underflows%"} + def check_if_stat_is_supported(self, stat: str, current_recipe: str): + """Returns True if stat is supported, raises ValueError otherwise.""" + columnwise = stat.endswith("_columnwise") + if columnwise: + stat = stat[: -len("_columnwise")] + recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe) + stat_without_recipe = stat.replace(recipe_from_stat + "_", "") + + if current_recipe == "" and recipe_from_stat == "": + raise ValueError( + f"Stat {stat} does not contain a recipe name and the current recipe is not set." + ) + + if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: + raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") + + if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: + raise ValueError( + f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" + " fp8_delayed_scaling and fp8_current_scaling." 
+ ) + + if recipe_from_stat == "fp8_delayed_scaling" and stat_without_recipe == "overflows%": + return True + + if recipe_from_stat in ["fp8_block_scaling"] and torch.cuda.get_device_capability()[0] < 9: + raise ValueError(f"Stat {stat} needs Hopper or later GPU.") + + if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10: + raise ValueError(f"Stat {stat} needs Blackwell or later GPU.") + + supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"] + if stat_without_recipe not in supported_stats: + raise ValueError( + f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}" + ) + + return True + + def get_recipe_from_stat(self, stat: str, default_recipe: str = ""): + """Returns the recipe name from the stat string.""" + columnwise_stat = stat.endswith("_columnwise") + for recipe_name in ALL_RECIPE_NAMES: + if recipe_name in stat: + return recipe_name, columnwise_stat + return default_recipe, columnwise_stat + + @contextmanager + def update_aux_dict( + self, + aux_dict: Dict, + recipe_name: str, + quantized_tensor: QuantizedTensor, + quantizer: Quantizer, + original_tensor: torch.Tensor, + recipes_in_stats: List[Tuple[str, bool]], + ): + """ + Updates the aux_dict with the quantized tensor for each recipe provided in recipes_in_stats. + It allows to compute stats for different recipes in the same iteration, + without recomputing the quantized tensor for each recipe for each stat. + Also updates usage of the quantized tensor with rowwise and columnwise usage. + Yields the aux_dict. + Needs to clean after usage, because it possibly change the usage of the quantized tensor. + """ + fp8_dtype = None + if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: + assert isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) + ) + fp8_dtype = quantizer.dtype + + aux_dict = { + recipe_name: quantized_tensor, + } + + old_rowwise_usage = quantizer.rowwise_usage + old_columnwise_usage = quantizer.columnwise_usage + for cur_recipe_name, cur_columnwise_stat in recipes_in_stats: + if recipe_name is not cur_recipe_name: + quantizer = _get_new_quantizer(cur_recipe_name, fp8_dtype) + aux_dict[cur_recipe_name] = quantizer(original_tensor) + elif isinstance(quantized_tensor, QuantizedTensor): + if cur_columnwise_stat: + quantized_tensor.update_usage(columnwise_usage=True) + else: + quantized_tensor.update_usage(rowwise_usage=True) + aux_dict[""] = quantized_tensor + aux_dict[cur_recipe_name] = quantized_tensor + try: + yield aux_dict + finally: + if isinstance(quantized_tensor, QuantizedTensor): + quantized_tensor.update_usage( + rowwise_usage=old_rowwise_usage, columnwise_usage=old_columnwise_usage + ) @api_method - def inspect_tensor_postquantize_enabled( - self, config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int + def inspect_tensor_enabled( + self, config: Dict, layer_name: str, tensor_name: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run inspect_tensor_postquantize() in the forward.""" - # check whether logging should happen in this iteration - return self._check_params(config, layer_name, iteration=iteration) + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter + return run_current, next_iter @api_method - def 
inspect_tensor_postquantize( + def inspect_tensor( self, config: Dict, layer_name: str, tensor_name: str, - tensor: Union[torch.Tensor, QuantizedTensor], - rowwise: bool, iteration: int, tp_group: torch.distributed.ProcessGroup, + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, ): """ API call used to collect the data about the tensor after process_tensor()/quantization. """ + assert rowwise_quantized_tensor is columnwise_quantized_tensor + assert ( + quantizer is not None + ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." - assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], ( - f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using" - " log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors." - ) + quantized_tensor = rowwise_quantized_tensor + assert isinstance( + quantized_tensor, QuantizedTensor + ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." + recipe_name = _get_recipe_name(quantizer) - # This API can be invoked twice - with the tensor and with the transpose. - # We want to collect the stats once. - if not rowwise: - return # tensor was already seen rowwise in the other gemm + for stat in config["stats"]: + self.check_if_stat_is_supported(stat, recipe_name) options = ( config.get("start_step", None), @@ -127,19 +297,9 @@ def inspect_tensor_postquantize( "fp8", ) - skip_reduction = False - reduction_group = debug_api.get_tensor_reduction_group() - reduce_within_microbatch = tensor_name != "weight" - if tensor_name == "weight": - if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group - else: - skip_reduction = True - - for stat in config["stats"]: - assert ( - stat in self._get_supported_stats_list() - ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." 
+ skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group + ) STATS_BUFFERS.try_add_buffer( layer_name=layer_name, @@ -150,10 +310,30 @@ def inspect_tensor_postquantize( reduce_within_microbatch=reduce_within_microbatch, ) - STATS_BUFFERS.feed(layer_name, tensor_name, options, tensor, iteration, skip_reduction) + recipes_in_stats = [ + self.get_recipe_from_stat(stat, default_recipe=recipe_name) for stat in config["stats"] + ] + + with self.update_aux_dict( + aux_dict={}, + recipe_name=recipe_name, + quantized_tensor=quantized_tensor, + quantizer=quantizer, + original_tensor=tensor, + recipes_in_stats=recipes_in_stats, + ) as aux_dict: + STATS_BUFFERS.feed( + layer_name, + tensor_name, + options, + tensor, + iteration, + skip_reduction, + aux_dict=aux_dict, + ) debug_api.log_message( - f"Feature={self.__class__.__name__}, API=inspect_tensor_postquantize: {tensor_name}", + f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}", layer_name, extra_cachable_args=(tensor_name,), ) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 75ff81d13..7ba2f9f77 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -4,7 +4,7 @@ """LogTensorStats Feature support for nvidia-dlframework-inspect""" -from typing import Dict, Union +from typing import Dict, Optional import torch @@ -12,13 +12,13 @@ from nvdlfw_inspect.registry import Registry, api_method import nvdlfw_inspect.api as debug_api -from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params @Registry.register_feature(namespace="transformer_engine") @@ -97,7 +97,15 @@ def inspect_tensor_enabled( self, config: Dict, layer_name: str, tensor_name: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run look_at_tensor_before_process() in the forward.""" - return self._check_params(config, layer_name, iteration=iteration) + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter + return run_current, next_iter @api_method def inspect_tensor( @@ -105,10 +113,13 @@ def inspect_tensor( config: Dict, layer_name: str, tensor_name: str, - tensor: Union[torch.Tensor, QuantizedTensor], iteration: int, tp_group: torch.distributed.ProcessGroup, - ): + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, + ): # pylint: disable=unused-argument """API call used to collect the data about the tensor before 
process_tensor()/quantization.""" assert ( @@ -125,14 +136,9 @@ def inspect_tensor( config.get("start_end_list", None), ) - skip_reduction = False - reduction_group = debug_api.get_tensor_reduction_group() - reduce_within_microbatch = tensor_name != "weight" - if tensor_name == "weight": - if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group - else: - skip_reduction = True + skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group + ) for stat in config["stats"]: assert ( diff --git a/transformer_engine/debug/features/per_tensor_scaling.py b/transformer_engine/debug/features/per_tensor_scaling.py index 7b4de0a18..dd1f42cf0 100644 --- a/transformer_engine/debug/features/per_tensor_scaling.py +++ b/transformer_engine/debug/features/per_tensor_scaling.py @@ -91,14 +91,14 @@ def fp8_gemm( self, config, layer_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call responsible for selecting between high-precision and FP8 GEMM execution.""" - return False + return False, None @api_method def modify_tensor_enabled( self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run process_tensor() in the forward.""" - return True + return True, iteration + 1 @api_method def modify_tensor( diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index 951e25063..aae2ec4e9 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -5,3 +5,58 @@ """ Utils for the debug features. """ + +import torch +import nvdlfw_inspect.api as debug_api + +from transformer_engine.debug.pytorch.debug_state import TEDebugState + + +def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup): + """ + Returns the statistics reduction parameters for the tensor. + """ + skip_reduction = False + reduction_group = debug_api.get_tensor_reduction_group() + reduce_within_microbatch = tensor_name != "weight" + if tensor_name == "weight": + if TEDebugState.weight_tensor_tp_group_reduce: + reduction_group = tp_group + else: + skip_reduction = True + return skip_reduction, reduction_group, reduce_within_microbatch + + +def next_enabled_iter(start_step, end_step, start_end_list, freq, iteration): + """ + Determines whether the feature should be enabled at the current iteration, + and computes the next iteration at which the feature will be enabled. + + Returns + ------- + run_current : bool + True if the feature should be enabled at the current iteration. + next_iter : int + The next iteration index at which the feature will be enabled. 
+ """ + + run_current = False + + if start_end_list: + intervals = sorted(start_end_list) + else: + start_step = 0 if start_step is None else start_step + end = float("inf") if end_step is None else end_step + intervals = [(start_step, end)] + + for s, e in intervals: + if iteration % freq == 0 and s <= iteration <= e: + run_current = True + + first = max(iteration + 1, s) + offset = first % freq + candidate = first if offset == 0 else first + (freq - offset) + if candidate <= e: + return run_current, candidate + + return run_current, None # No next iteration found diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 4be465f8e..f07602d23 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -10,6 +10,7 @@ from collections import defaultdict +from typing import Dict import torch from nvdlfw_inspect.utils import gather_along_first_dim @@ -20,6 +21,7 @@ DEPENDENCIES, stats_to_num, ) +from transformer_engine.debug.pytorch.debug_state import TEDebugState class _Buffer: @@ -65,14 +67,17 @@ def _gather_helper_stats(self) -> torch.Tensor: gathered_buffer, _ = gather_along_first_dim( self._buffer.unsqueeze(0), process_group=self.reduction_group ) - return gathered_buffer[mask.to(bool)] + return gathered_buffer[mask.to(torch.bool)] - def feed(self, tensor, iteration): + def feed(self, tensor, iteration, aux_dict=None): """ feed() is used to add tensor for computing the statistics. Because of the microbatching, feed() can be used multiple times for one log(). + The aux_dict is used to share common computation between different stats. + For example for LogFp8TensorStats in can contain quantized tensors in different precisions. + The main reason of this design: need to combine results for already processed tensors with the result of the new tensor. """ @@ -95,7 +100,7 @@ def feed(self, tensor, iteration): # save stats for tensor to tmp buffer for stat_name in self.stats_to_compute: fn, _ = STATS[stat_name] - self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor) + self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict) # [num_buffers, num_stats] buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0) @@ -106,7 +111,7 @@ def feed(self, tensor, iteration): self._new_buffer[stats_to_num[stat_name]] = combinator(buffers) else: fn = STATS[stat_name][0] - self._new_buffer[stats_to_num[stat_name]] = fn(tensor) + self._new_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict) self._buffer.copy_(self._new_buffer) @@ -125,7 +130,6 @@ def log(self): for stat_name in self.stats_to_log: combiner = STATS[stat_name][1] stat_value = combiner(gathered_helper_stats) - MetricLogger.log_scalar( f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration ) @@ -146,10 +150,41 @@ def __init__(self): self.buffers = {} # (layer_name, tensor_name) -> buffer self.reduction_group_to_buffer = defaultdict(list) + # Logging stats involves synchronization between nodes + # and non-trivial cpu overhead. + # It should be only done if absolutely necessary. + # This variables helps to determine if we can reduce. + self.at_least_one_layer_fed = False + self.layers_to_next_iter: Dict[str, int] = {} + + def _if_run_reduction(self) -> bool: + """ + Returns True if reduction should be run. + + This is the case if at least one layer logged stats. + If not, it may be the case that some layer was not run on this node. 
+ If we know that such layers on all other nodes do not log this time, + we can not reduce. If this in not the case, we should reduce. + + To ensure corretness, we assume that every layer is invoked at first forward pass. + If this is not the case, hang might happen. + """ + if self.at_least_one_layer_fed: + return True + iteration = TEDebugState.get_iteration() + for _, next_iter in self.layers_to_next_iter.items(): + # Note that layer can be not run for many iterations, + # in this case we will synchronize until every step until we get any information from it. + if iteration >= next_iter: + return True + return False + def reset(self): """Resets all buffers.""" self.buffers = {} # (layer_name, tensor_name) -> buffer self.reduction_group_to_buffer = defaultdict(list) + self.at_least_one_layer_fed = False + self.layers_to_next_iter: Dict[str, int] = {} def try_add_buffer( self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch @@ -161,14 +196,25 @@ def try_add_buffer( self.buffers[(layer_name, tensor_name, options)] = buffer self.reduction_group_to_buffer[reduction_group].append(buffer) - def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction): - """Feeds the tensor into the respective buffer.""" + def feed( + self, layer_name, tensor_name, options, tensor, iteration, skip_reduction, aux_dict=None + ): + """ + Feeds the tensor into the respective buffer. + + The aux_dict is used to share common computation between different stats. + For example for LogFp8TensorStats in can contain quantized tensors in different precisions. + """ + self.at_least_one_layer_fed = True buffer = self.buffers[(layer_name, tensor_name, options)] - buffer.feed(tensor, iteration) + buffer.feed(tensor, iteration, aux_dict) buffer.skip_reduction = skip_reduction def log_stats(self): """Logs the stats from all the buffers.""" + if not self._if_run_reduction(): + return {} + output = {} for reduction_group, buffers in self.reduction_group_to_buffer.items(): changed_buffers = [ @@ -181,7 +227,7 @@ def log_stats(self): for _, buffer in changed_buffers: stats = buffer.log() output.update(stats) - + self.at_least_one_layer_fed = False return output diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index ed32de1ae..2fa6985ac 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -8,8 +8,9 @@ import math import torch - -MAX_FP8_VALUE_INT8 = 126 +import torch.nn.functional as F +import transformer_engine_torch as tex +from transformer_engine.common.recipe import Format @torch.compile @@ -49,6 +50,29 @@ def compute_std(variances, numels, sums): return torch.sqrt(compute_variance(variances, numels, sums)) +def compute_fp8_delayed_scaling_overflows_num(tensor, quantized_tensor): + """Computes the overflows of the tensor.""" + scale_inv = quantized_tensor._scale_inv + dtype = quantized_tensor._fp8_dtype + + # Map each supported FP8 dtype to its corresponding max forward value. + dtype_to_max = { + tex.DType.kFloat8E4M3: Format.E4M3.value.max_fwd, + tex.DType.kFloat8E5M2: Format.E5M2.value.max_fwd, + } + + if dtype not in dtype_to_max: + raise ValueError( + f"Unsupported FP8 dtype {dtype} passed to compute_fp8_delayed_scaling_overflows_num()." 
+ ) + + fp8_max = dtype_to_max[dtype] + fp8_min = -fp8_max + + overflows = (tensor > fp8_max * scale_inv) | (tensor < fp8_min * scale_inv) + return overflows.sum() + + # buffers is tensor of shape [nr_buffers, nr_stats] def _get(buffers, stat_name): stat_nr = stats_to_num[stat_name] @@ -68,10 +92,12 @@ def _get(buffers, stat_name): "cur_amax": 9, "dynamic_range_top": 10, "dynamic_range_bottom": 11, - "underflows_num": 12, - "std": 13, - "dynamic_range": 14, - "underflows%": 15, + "std": 12, + "dynamic_range": 13, + "fp8_delayed_scaling_overflows_num": 14, + "fp8_delayed_scaling_overflows%": 15, + "overflows_num": 16, + "overflows%": 17, } DEPENDENCIES = { @@ -87,62 +113,217 @@ def _get(buffers, stat_name): "cur_amax": {"cur_amax"}, "dynamic_range_top": {"dynamic_range_top"}, "dynamic_range_bottom": {"dynamic_range_bottom"}, - "underflows_num": {"underflows_num"}, "std": {"variance", "numel", "sum"}, "dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"}, - "underflows%": {"underflows_num", "numel"}, + "fp8_delayed_scaling_overflows_num": {"fp8_delayed_scaling_overflows_num"}, + "fp8_delayed_scaling_overflows%": {"fp8_delayed_scaling_overflows_num", "numel"}, + "overflows_num": {"overflows_num"}, + "overflows%": {"overflows_num", "numel"}, } STATS = { - "min": (torch.min, lambda buffers: min(_get(buffers, "min"))), - "max": (torch.max, lambda buffers: max(_get(buffers, "max"))), - "sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))), - "mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))), + "min": (lambda x, aux_dict: torch.min(x), lambda buffers: min(_get(buffers, "min"))), + "max": (lambda x, aux_dict: torch.max(x), lambda buffers: max(_get(buffers, "max"))), + "sum": (lambda x, aux_dict: torch.sum(x), lambda buffers: sum(_get(buffers, "sum"))), + "mean": ( + lambda x, aux_dict: torch.mean(x), + lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")), + ), "numel": ( - lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), + lambda x, aux_dict: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), lambda buffers: sum(_get(buffers, "numel")), ), - "l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))), + "l1_norm": ( + lambda x, aux_dict: torch.norm(x, p=1), + lambda buffers: sum(_get(buffers, "l1_norm")), + ), "l2_norm_square": ( - lambda x: torch.sum(x**2), + lambda x, aux_dict: torch.sum(x**2), lambda buffers: sum(_get(buffers, "l2_norm_square")), ), "l2_norm": ( - lambda x: torch.norm(x, p=2), + lambda x, aux_dict: torch.norm(x, p=2), lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))), ), "variance": ( - torch.var, + lambda x, aux_dict: torch.var(x), lambda buffers: compute_variance( _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") ), ), - "cur_amax": (lambda x: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))), + "cur_amax": (lambda x, aux_dict: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))), "dynamic_range_top": ( - _compute_dynamic_range_top, + lambda x, aux_dict: _compute_dynamic_range_top(x), lambda buffers: max(_get(buffers, "dynamic_range_top")), ), "dynamic_range_bottom": ( - _compute_dynamic_range_bottom, + lambda x, aux_dict: _compute_dynamic_range_bottom(x), lambda buffers: min(_get(buffers, "dynamic_range_bottom")), ), - "underflows_num": ( - lambda x: (x.get_data_tensors()[0] == 0).sum(), - lambda buffers: sum(_get(buffers, "underflows_num")), - ), "std": ( - 
torch.std, + lambda x, aux_dict: torch.std(x), lambda buffers: compute_std( _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") ), ), "dynamic_range": ( - lambda x: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x), + lambda x, aux_dict: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x), lambda buffers: max(_get(buffers, "dynamic_range_top")) - min(_get(buffers, "dynamic_range_bottom")), ), - "underflows%": ( - lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100, - lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")), + "fp8_delayed_scaling_overflows_num": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num( + x, aux_dict["fp8_delayed_scaling"] + ), + lambda buffers: sum(_get(buffers, "fp8_delayed_scaling_overflows_num")), + ), + "fp8_delayed_scaling_overflows%": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num( + x, aux_dict["fp8_delayed_scaling"] + ) + / x.numel() + * 100, + lambda buffers: 100 + * sum(_get(buffers, "fp8_delayed_scaling_overflows_num")) + / sum(_get(buffers, "numel")), + ), + "overflows_num": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]), + lambda buffers: sum(_get(buffers, "overflows_num")), + ), + "overflows%": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]) + / x.numel() + * 100, + lambda buffers: 100 * sum(_get(buffers, "overflows_num")) / sum(_get(buffers, "numel")), ), } + +FP8_NEGATIVE_ZERO = 128 # represnts -0.0 in fp8 + + +def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor: + """Count the number of non-zero elements in the fp8 data.""" + fp8_data = fp8_data.view(dtype=torch.uint8) + zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8) + return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum() + + +def add_underflows_stats(recipe_name: str, columnwise: bool = False): + """Register *both* underflow stats (num and %) for the given recipe.""" + columnwise_suffix = "_columnwise" if columnwise else "" + + # Stat names + stat_num = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows_num{columnwise_suffix}" + stat_pct = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}" + + stats_to_num[stat_num] = len(stats_to_num) + stats_to_num[stat_pct] = len(stats_to_num) + + STATS[stat_num] = ( + lambda x, aux_dict: x.count_nonzero() + - count_nonzero_fp8( + aux_dict[recipe_name].get_data_tensors( + rowwise_data=not columnwise, columnwise_data=columnwise + ) + ), + lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), + ) + STATS[stat_pct] = ( + lambda x, aux_dict: ( + x.count_nonzero() + - count_nonzero_fp8( + aux_dict[recipe_name].get_data_tensors( + rowwise_data=not columnwise, columnwise_data=columnwise + ) + ) + ) + / aux_dict[recipe_name].numel() + * 100, + lambda buffers, _sn_num=stat_num: 100 + * sum(_get(buffers, _sn_num)) + / sum(_get(buffers, "numel")), + ) + + DEPENDENCIES[stat_num] = {stat_num} + DEPENDENCIES[stat_pct] = {stat_num, "numel"} + + +def add_scale_inv_stats(recipe_name: str, columnwise: bool = False): + """Register *both* scale-inv min and max stats for a given recipe. + + This replaces the earlier separate helpers and avoids duplicated boilerplate. + """ + # Determine which attribute holds the scale-inverse tensor. 
+ + def get_scale_inv(quantized_tensor, columnwise): + if hasattr(quantized_tensor, "_scale_inv"): + return getattr(quantized_tensor, "_scale_inv") + if columnwise: + return getattr(quantized_tensor, "_columnwise_scale_inv") + return getattr(quantized_tensor, "_rowwise_scale_inv") + + columnwise_suffix = "_columnwise" if columnwise else "" + # Prepare stat names. + stat_name_min = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_min{columnwise_suffix}" + ) + stat_name_max = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_max{columnwise_suffix}" + ) + + # Assign indices in `stats_to_num` (order matters — keep insertion order deterministic). + stats_to_num[stat_name_min] = len(stats_to_num) + stats_to_num[stat_name_max] = len(stats_to_num) + + # Capture the attribute name inside lambdas via default args to avoid late binding. + STATS[stat_name_min] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(), + lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)), + ) + STATS[stat_name_max] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(), + lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)), + ) + + DEPENDENCIES[stat_name_min] = {stat_name_min} + DEPENDENCIES[stat_name_max] = {stat_name_max} + + +def add_mse_stats(recipe_name: str, columnwise: bool = False): + """Register mse and total_square_error stats for the recipe.""" + columnwise_suffix = "_columnwise" if columnwise else "" + + stat_mse = f"{recipe_name}{'_' if recipe_name != '' else ''}mse{columnwise_suffix}" + stat_err = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}" + ) + + stats_to_num[stat_mse] = len(stats_to_num) + stats_to_num[stat_err] = len(stats_to_num) + + STATS[stat_mse] = ( + lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="mean"), + lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)) + / sum(_get(buffers, "numel")), + ) + STATS[stat_err] = ( + lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="sum"), + lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)), + ) + + DEPENDENCIES[stat_err] = {stat_err} + DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"} + + +for _columnwise in [True, False]: + for _recipe_name in [ + "", # default recipe + "fp8_delayed_scaling", + "mxfp8", + "fp8_current_scaling", + "fp8_block_scaling", + ]: + add_underflows_stats(_recipe_name, _columnwise) + add_scale_inv_stats(_recipe_name, _columnwise) + add_mse_stats(_recipe_name, _columnwise) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 2b859800a..d564ca8e9 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -22,6 +22,7 @@ prepare_for_saving, restore_from_saved, ) +from transformer_engine.debug.pytorch.debug_state import TEDebugState aten = torch.ops.aten @@ -53,14 +54,13 @@ def __init__( parent_quantizer: Optional[Quantizer], tp_group: torch.distributed.ProcessGroup, ): - import nvdlfw_inspect.api as debug_api super().__init__(rowwise=True, columnwise=True) self.layer_name = layer_name self.tensor_name = tensor_name self.parent_quantizer = parent_quantizer self.tp_group = tp_group # used in inspect_tensor calls - self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + self.iteration = TEDebugState.get_iteration() # 
.internal = True is slightly faster, but results # in errors when caching the weights. @@ -70,6 +70,12 @@ def __init__( self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] + # next iteration when this quantizer will call any API + # it is None at the init and it is computed after_enabled api calls. + # None at the beginning means that if nothing will be done, + # this quantizer will never call any API. + self.next_debug_iter = None + # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, # rowwise_tensor_plan, and columnwise_tensor_plan are computed. # These fields indicate the path where API calls will be inserted. @@ -102,15 +108,21 @@ def get_plans_for_output_tensors(self) -> Tuple[bool, str]: """ import nvdlfw_inspect.api as debug_api - inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( - layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + inspect_tensor_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) ) - modify_enabled = debug_api.transformer_engine.modify_tensor_enabled( - layer_name=self.layer_name, - gemm=self.rowwise_gemm_name, - tensor_name=self.tensor_name, - iteration=self.iteration, + + modify_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) ) + plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION return inspect_tensor_enabled, plan @@ -121,10 +133,13 @@ def get_enabled_look_at_tensors(self): """ import nvdlfw_inspect.api as debug_api - inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( - layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + inspect_tensor_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) ) - inspect_tensor_postquantize_enabled_rowwise = ( + + inspect_tensor_postquantize_enabled_rowwise = self.process_enabled_api_call( debug_api.transformer_engine.inspect_tensor_postquantize_enabled( layer_name=self.layer_name, tensor_name=self.tensor_name, @@ -132,7 +147,8 @@ def get_enabled_look_at_tensors(self): gemm=self.rowwise_gemm_name, ) ) - inspect_tensor_postquantize_enabled_columnwise = ( + + inspect_tensor_postquantize_enabled_columnwise = self.process_enabled_api_call( debug_api.transformer_engine.inspect_tensor_postquantize_enabled( layer_name=self.layer_name, tensor_name=self.tensor_name, @@ -140,7 +156,6 @@ def get_enabled_look_at_tensors(self): gemm=self.columnwise_gemm_name, ) ) - return ( inspect_tensor_enabled, inspect_tensor_postquantize_enabled_rowwise, @@ -158,42 +173,54 @@ def get_tensors_plan(self): rowwise_plan = None columnwise_plan = None - modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled( - layer_name=self.layer_name, - gemm=self.rowwise_gemm_name, - tensor_name=self.tensor_name, - iteration=self.iteration, + modify_rowwise = self.process_enabled_api_call( + debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) ) + if modify_rowwise: rowwise_plan = API_CALL_MODIFY else: if 
self.parent_quantizer is not None: - fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( - layer_name=self.layer_name, - gemm=self.rowwise_gemm_name, - iteration=self.iteration, + fp8_quantize = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + iteration=self.iteration, + ) ) + if fp8_quantize: rowwise_plan = STANDARD_FP8_QUANTIZE if rowwise_plan is None: rowwise_plan = HIGH_PRECISION if self.columnwise_gemm_name is not None: - modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled( - layer_name=self.layer_name, - gemm=self.columnwise_gemm_name, - tensor_name=self.tensor_name, - iteration=self.iteration, + modify_columnwise = self.process_enabled_api_call( + debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) ) + if modify_columnwise: columnwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( - layer_name=self.layer_name, - gemm=self.columnwise_gemm_name, - iteration=self.iteration, + fp8_quantize = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + iteration=self.iteration, + ) ) + if fp8_quantize: columnwise_plan = STANDARD_FP8_QUANTIZE if columnwise_plan is None: @@ -229,8 +256,11 @@ def _call_inspect_tensor_api( "layer_name": self.layer_name, "tensor": tensor, "tensor_name": self.tensor_name, - "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count, + "iteration": TEDebugState.get_iteration(), "tp_group": self.tp_group, + "columnwise_quantized_tensor": columnwise_gemm_tensor, + "rowwise_quantized_tensor": rowwise_gemm_tensor, + "quantizer": self.parent_quantizer, } if tensor is not None and self.inspect_tensor_enabled: debug_api.transformer_engine.inspect_tensor(**args) @@ -238,6 +268,10 @@ def _call_inspect_tensor_api( if self.output_tensor: return + del args["columnwise_quantized_tensor"] + del args["rowwise_quantized_tensor"] + del args["quantizer"] + if ( self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] and self.inspect_tensor_postquantize_enabled_rowwise @@ -245,6 +279,7 @@ def _call_inspect_tensor_api( args["tensor"] = rowwise_gemm_tensor args["rowwise"] = True debug_api.transformer_engine.inspect_tensor_postquantize(**args) + if ( self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] and self.inspect_tensor_postquantize_enabled_columnwise @@ -270,22 +305,14 @@ def quantize( # 1. If there is fp8 quantization in at least one of the gemms, # the quantization using the self.parent_quantizer is performed. - # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise - rowwise_gemm_quantize = ( - self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE - ) - columnwise_gemm_quantize = ( - self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE - ) - if columnwise_gemm_quantize and not rowwise_gemm_quantize: - rowwise_gemm_quantize = True # only columnwise quantization not implemented + self._update_parent_quantizer_usage() + # Only columnwise quantization is not supported. 
+ if self.parent_quantizer is not None: + if not self.parent_quantizer.rowwise_usage and self.parent_quantizer.columnwise_usage: + self.parent_quantizer.set_usage(rowwise=True) rowwise_gemm_tensor, columnwise_gemm_tensor = None, None if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: - self.parent_quantizer.set_usage( - rowwise=True, - columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported - ) quantized_tensor = self.parent_quantizer(tensor) # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, # one tensor with columnwise=True and rowwise=True is computed @@ -341,7 +368,6 @@ def quantize( quantizer=self, layer_name=self.layer_name, tensor_name=self.tensor_name, - original_tensor=tensor, ) def process_gemm_output(self, tensor: torch.Tensor): @@ -375,6 +401,26 @@ def make_empty( return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device) return torch.empty(shape, dtype=dtype, device=device) + def any_feature_enabled(self) -> bool: + """Returns bool if there is at least one API call enabled.""" + if self.output_tensor: + return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + # pylint: disable=too-many-boolean-expressions + if ( + self.inspect_tensor_enabled + or self.inspect_tensor_postquantize_enabled_rowwise + or self.inspect_tensor_postquantize_enabled_columnwise + or self.rowwise_tensor_plan == API_CALL_MODIFY + or self.columnwise_tensor_plan == API_CALL_MODIFY + ): + return True + if self.parent_quantizer is not None: + if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + return False + def calibrate(self, tensor: torch.Tensor): """Calibration override, should not be invoked.""" raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported") @@ -446,29 +492,70 @@ def update_quantized( self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor) - def any_feature_enabled(self) -> bool: - """Returns bool if there is at least one API call enabled.""" - if self.output_tensor: - return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY - if ( - self.inspect_tensor_enabled - or self.inspect_tensor_postquantize_enabled_rowwise - or self.inspect_tensor_postquantize_enabled_columnwise - or self.rowwise_tensor_plan == API_CALL_MODIFY - or self.columnwise_tensor_plan == API_CALL_MODIFY - ): - return True - if self.parent_quantizer is not None: - if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: - return True - if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: - return True - return False + def get_next_debug_iter(self) -> Optional[int]: + """ + Returns the next iteration for which the debug is enabled for this tensor. + If the next iteration is None, then the debug is not enabled for this tensor. + """ + return self.next_debug_iter def _get_compatible_recipe(self) -> Union[type[Recipe], None]: """Probably not needed for debug quantizer""" return None + def process_enabled_api_call( + self, enabled_call_output: bool | Tuple[bool, Optional[int]] + ) -> bool: + """ + Process enabled API call output. + Updates self.next_debug_iter field accordingly. + Return the bool representing if the API call is enabled. 
+ """ + if isinstance(enabled_call_output, tuple): + assert len(enabled_call_output) == 2, "Expected a tuple of length 2" + enabled_bool, next_iter = enabled_call_output + else: + enabled_bool = enabled_call_output + next_iter = self.iteration + 1 + + if self.next_debug_iter is None: + self.next_debug_iter = next_iter + elif next_iter is not None: + # If next iter is None, that means that call will never be enabled. + self.next_debug_iter = min(self.next_debug_iter, next_iter) + + return enabled_bool + + def supports_only_rowwise_all_gather(self) -> bool: + if self.parent_quantizer is not None: + return self.parent_quantizer.supports_only_rowwise_all_gather() + return False + + def _update_parent_quantizer_usage(self): + """ + Updates the usage of the parent quantizer. + """ + rowwise_gemm_quantize = ( + self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + columnwise_gemm_quantize = ( + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + + if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + self.parent_quantizer.set_usage( + rowwise=rowwise_gemm_quantize, + columnwise=columnwise_gemm_quantize, + ) + + def set_usage(self, rowwise: bool = None, columnwise: bool = None): + """ + Sets the usage of the quantizer. + """ + super().set_usage(rowwise=rowwise, columnwise=columnwise) + if not self.output_tensor: + self._update_parent_quantizer_usage() + class DebugQuantizedTensor(QuantizedTensorBase): """ @@ -484,7 +571,6 @@ def __init__( quantizer, layer_name=None, tensor_name=None, - original_tensor=None, ): self.rowwise_gemm_tensor = rowwise_gemm_tensor @@ -492,7 +578,6 @@ def __init__( self.quantizer = quantizer self._layer_name = layer_name self._tensor_name = tensor_name - self._original_tensor = original_tensor def prepare_for_saving(self): """ " Prepare for saving method override""" @@ -501,6 +586,7 @@ def prepare_for_saving(self): if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor else [self.rowwise_gemm_tensor] ) + tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save) self.tensors_to_save = tensor_objects_list # pylint: disable=unbalanced-tuple-unpacking @@ -519,6 +605,7 @@ def restore_from_saved(self, tensors): else: self.rowwise_gemm_tensor = tensor_objects_list[0] self.columnwise_gemm_tensor = self.rowwise_gemm_tensor + return saved_tensors def quantize_(self, tensor, *, noop_flag=None): @@ -542,3 +629,27 @@ def size(self): def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None): """Update usage of the tensor.""" + if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor: + # If the same object is used both for rowwise and columnwise gemms, + # there is no benefit in erasing the usage of one of them. + # And there are scenarios when not deleting the usage of one of them is needed. + # For example when we want to recreate columnwise from rowwise. + if rowwise_usage is False: + self.rowwise_gemm_tensor = None + if columnwise_usage is False: + self.columnwise_gemm_tensor = None + + if isinstance(self.rowwise_gemm_tensor, QuantizedTensor): + self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + if isinstance(self.columnwise_gemm_tensor, QuantizedTensor): + self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + + if rowwise_usage and self.rowwise_gemm_tensor is None: + raise RuntimeError( + "Cannot recreate rowwise tensor from columnwise tensor in debug mode." 
+ ) + + if columnwise_usage and self.columnwise_gemm_tensor is None: + raise RuntimeError( + "Cannot recreate columnwise tensor from rowwise tensor is debug mode." + ) diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py index 11edb3641..c47e859bb 100644 --- a/transformer_engine/debug/pytorch/debug_state.py +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -62,6 +62,13 @@ def set_weight_tensor_tp_group_reduce(cls, enabled): """Sets weight tensor reduction mode.""" cls.weight_tensor_tp_group_reduce = enabled + @classmethod + def get_iteration(cls): + """Returns the current iteration.""" + import nvdlfw_inspect.api as debug_api + + return debug_api.DEBUG_MANAGER._trainer_iteration_count + def set_weight_tensor_tp_group_reduce(enabled): """Sets weight tensor reduction mode.""" diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py index 4aea05333..18ed3556f 100644 --- a/transformer_engine/debug/pytorch/utils.py +++ b/transformer_engine/debug/pytorch/utils.py @@ -4,6 +4,25 @@ """Utils functions for the debug module.""" +from typing import Optional + + +def next_iter_when_debug_should_be_run(quantizers) -> Optional[int]: + """ + Returns next iteration at which the debug should be run. + If debug will never be run for this layer, returns None. + """ + + out = None + for q in quantizers: + if q.get_next_debug_iter() is not None: + if out is None: + out = q.get_next_debug_iter() + else: + out = min(out, q.get_next_debug_iter()) + + return out + def any_feature_enabled(quantizers): """Returns True if at least one API call is made from DebugQuantizer.""" diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 55ffad93e..0b5e43402 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -38,19 +38,10 @@ from .quantize import NVTE_FP8_COLLECTION_NAME from .sharding import MeshResource -from .sharding import MajorShardingType, ShardingResource, ShardingType from ..common.utils import deprecate_wrapper from ..common.utils import DeprecatedEnum -MajorShardingType = DeprecatedEnum( - MajorShardingType, "MajorShardingType is deprecating in the near feature." -) -ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near feature.") -ShardingResource = deprecate_wrapper( - ShardingResource, - "ShardingResource is renamed to MeshResource, and will be removed in the near feature.", -) __all__ = [ "NVTE_FP8_COLLECTION_NAME", @@ -58,9 +49,6 @@ "update_collections", "get_delayed_scaling", "MeshResource", - "MajorShardingType", - "ShardingResource", - "ShardingType", "flax", "quantize", ] diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index ef6def2d0..12b35ec43 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -14,7 +14,7 @@ from . import cpp_extensions as tex -from .quantize.tensor import ScaledTensor +from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer @@ -22,7 +22,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, -) -> Union[jnp.ndarray, ScaledTensor]: +) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. This function applies a sequence of activation functions to the input tensor. 
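
With this change activation() always returns a plain jnp.ndarray; quantized intermediates are dequantized internally before being handed back. A minimal usage sketch — not part of the patch, and the import path and tensor shape are assumptions:

    # Illustrative usage sketch (not part of the patch).
    import jax.numpy as jnp
    from transformer_engine.jax.activation import activation

    x = jnp.ones((8, 1, 128), dtype=jnp.bfloat16)  # (..., ACT_DIM=1, K) for a non-gated activation
    y = activation(x, ("gelu",))                   # no quantizer: a plain jnp.ndarray comes back
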
@@ -72,8 +72,8 @@ def _activation_fwd_rule(x, activation_type, quantizer): Tuple of (output, context) for backward pass """ fwd_output = tex.act_lu(x, activation_type, quantizer) - if isinstance(fwd_output, ScaledTensor): - fwd_output = fwd_output.dequantize() + # This is a no-op for higher-precision tensors + fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) @@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g): (x, _) = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type) + # No quantization is used in this VJP backward, so the output should + # always be a NoScaleTensor + assert isinstance(dx, NoScaleTensor) + dx = dx.data return (dx, None) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index fe4109cee..093146162 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -855,7 +855,7 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)) +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -872,6 +872,7 @@ def _fused_attn( context_parallel_strategy: CPStrategy, context_parallel_causal_load_balanced: bool, context_parallel_axis: str, + context_checkpoint_name: str = "context", ): output, _ = _fused_attn_fwd_rule( qkv, @@ -889,6 +890,7 @@ def _fused_attn( context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, + context_checkpoint_name=context_checkpoint_name, ) return output @@ -909,6 +911,7 @@ def _fused_attn_fwd_rule( context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, + context_checkpoint_name, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -927,9 +930,9 @@ def _fused_attn_fwd_rule( context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) - output = checkpoint_name(output, "context") - softmax_aux = checkpoint_name(softmax_aux, "context") - rng_state = checkpoint_name(rng_state, "context") + output = checkpoint_name(output, context_checkpoint_name) + softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name) + rng_state = checkpoint_name(rng_state, context_checkpoint_name) return output, ( qkv, bias, @@ -952,9 +955,11 @@ def _fused_attn_bwd_rule( context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, + context_checkpoint_name, ctx, dz, ): + del context_checkpoint_name ( qkv, bias, @@ -1012,6 +1017,7 @@ def fused_attn( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", + context_checkpoint_name: str = "context", ): """ Perform cuDNN fused attention. @@ -1044,6 +1050,7 @@ def fused_attn( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass. Returns: (jnp.ndarray): The output tensor from the fused attention. 
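
The default tag remains "context"; exposing context_checkpoint_name lets callers pick a distinct name and pair it with a names-based rematerialization policy. A minimal sketch, assuming a caller-chosen tag "attn_context" and a stand-in block in place of a real fused_attn call:

    # Illustrative sketch (not part of the patch): keep only tensors tagged "attn_context" under remat.
    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def block(x):
        # Stand-in for a layer calling fused_attn(..., context_checkpoint_name="attn_context");
        # fused_attn tags its forward outputs with checkpoint_name() in the same way.
        y = checkpoint_name(jnp.sin(x) * x, "attn_context")
        return jnp.sum(y ** 2)

    policy = jax.checkpoint_policies.save_only_these_names("attn_context")
    grad_fn = jax.grad(jax.checkpoint(block, policy=policy))
    print(grad_fn(jnp.ones((4,))))
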
@@ -1116,6 +1123,7 @@ def fused_attn( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, + context_checkpoint_name=context_checkpoint_name, ) return output diff --git a/transformer_engine/jax/checkpoint_policies.py b/transformer_engine/jax/checkpoint_policies.py new file mode 100644 index 000000000..a03db09b9 --- /dev/null +++ b/transformer_engine/jax/checkpoint_policies.py @@ -0,0 +1,38 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Checkpoint policies for Transformer Engine in JAX. + +This module provides JAX checkpoint policies that are compatible with Transformer Engine's custom primitives. +""" + +import jax +from .cpp_extensions.gemm import GemmPrimitive, GroupedGemmPrimitive + + +__all__ = [ + "te_gemms_saveable", + "dots_and_te_gemms_with_no_batch_dims", + "checkpoint_dots_and_te_gemms", +] + + +def te_gemms_saveable(prim, *_, **__) -> bool: + """Checkpoint policy for Transformer Engine GEMMs.""" + is_te_gemm = prim in {GemmPrimitive.outer_primitive, GroupedGemmPrimitive.outer_primitive} + # Workaround to include JAX's scaled_matmul until JAX checkpoint policies for dots are + # updated to include it. + is_jax_scaled_matmul = prim.name == "scaled_matmul_wrapper" + + return is_te_gemm or is_jax_scaled_matmul + + +dots_and_te_gemms_with_no_batch_dims = jax.checkpoint_policies.save_from_both_policies( + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + te_gemms_saveable, +) + +checkpoint_dots_and_te_gemms = jax.checkpoint_policies.save_from_both_policies( + jax.checkpoint_policies.checkpoint_dots, + te_gemms_saveable, +) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 0cd8f5a36..ef2643359 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -32,7 +32,7 @@ ) from .quantization import _jax_dbias, _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -922,14 +922,14 @@ def shardy_sharding_rule( class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): - """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): - """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: +def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ @@ -948,11 +948,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S x = jnp.squeeze(x, axis=-2) if quantizer: return quantizer.quantize(x, flatten_axis=-1) - return x + return NoScaleTensor(data=x, amax=None) def _jax_quantize_dact_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, @@ -970,7 +970,9 @@ def _jax_quantize_dact_dbias( _, vjp_func = jax.vjp( partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) ) - (dx,) = vjp_func(dz.astype(jnp.float32)) + # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. + dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) + (dx,) = vjp_func(dz) dbias = None if is_dbias: @@ -980,6 +982,7 @@ def _jax_quantize_dact_dbias( dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2) else: dx = dx.astype(x.dtype) + dx = NoScaleTensor(data=dx, amax=None) return dx, dbias @@ -988,7 +991,6 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -997,7 +999,6 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1042,16 +1043,16 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) - if noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype - ) + out = NoScaleTensor( + data=out, + amax=None, + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( - x=x.astype(jnp.float32), + x=x, activation_type=activation_type, quantizer=None, ) @@ -1099,7 +1100,6 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1110,7 +1110,6 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1153,19 +1152,10 @@ def quantize_dact_dbias( if is_dbias: dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - if noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, - None, - output, - None, - ScalingMode.NO_SCALING, - dq_dtype=output.dtype, - ), - dbias, - ) - + output = NoScaleTensor( + data=output, + amax=None, + ) return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet @@ -1174,7 +1164,7 @@ def quantize_dact_dbias( dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None ) return _quantize_dbias_impl( - out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 ) is_gated = act_len == 2 @@ -1195,13 +1185,13 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( - dz=dz.astype(jnp.float32), - x=x.astype(jnp.float32), + dz=dz, + x=x, activation_type=activation_type, quantizer=None, ) out, dbias = _quantize_dbias_impl( - out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 ) return out, dbias @@ -1265,7 +1255,6 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1275,7 +1264,6 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. 
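
The hunks above establish a single convention: when no quantizer is supplied, these extensions return a NoScaleTensor wrapper rather than a bare array, and callers unwrap it via .data (or .dequantize(), described above as a no-op for higher-precision tensors). A minimal sketch of the unquantized path — not part of the patch; import paths follow the modules edited above and may differ in the installed package:

    # Illustrative sketch (not part of the patch).
    import jax.numpy as jnp
    from transformer_engine.jax.cpp_extensions.activation import act_lu
    from transformer_engine.jax.quantize import NoScaleTensor

    x = jnp.ones((8, 1, 128), dtype=jnp.bfloat16)  # (..., ACT_DIM=1, K), non-gated
    out = act_lu(x, ("gelu",), quantizer=None)
    assert isinstance(out, NoScaleTensor)          # unquantized outputs are wrapped, not bare arrays
    y = out.data                                   # explicit unwrap for downstream JAX ops
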
@@ -1286,6 +1274,5 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, - noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 04fcf1a8d..45d3d8b59 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -38,6 +38,7 @@ te_dtype_to_jax_dtype, get_padded_spec, get_cudnn_version, + get_all_device_compute_capability, ) from ..sharding import ( global_mesh_resource, @@ -2772,6 +2773,11 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) + if 100 in get_all_device_compute_capability(): + assert not ( + attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 + ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index bf3b3b7fd..92c09bb68 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -6,6 +6,7 @@ """JAX/TE base custom ops""" import os import re +import warnings from abc import ABCMeta, abstractmethod from functools import partial from packaging import version @@ -33,19 +34,77 @@ class BasePrimitive(metaclass=ABCMeta): name = None + _is_enabled = True + + # Default list of primitives to disable for all recipes + _default_disable_names = [] + @classmethod def enabled(cls): """ - A custom call is marked as disabled if the `cls.__name__` does not fully match the - `NVTE_JAX_CUSTOM_CALLS_RE` pattern. - This uses the Python class name of the primitive definitions that inherit from BasePrimitive. - By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names. - For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`. + Determines if a custom call is enabled based on a state variable and environment variables. + Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first, then falls back to the deprecated `NVTE_JAX_CUSTOM_CALLS_RE` (regex pattern), + and finally to the internal state `_is_enabled` if neither is set. + + Environment Variables: + 1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific primitives or a single value 'true' or 'false' to enable/disable all primitives. + - Example 1 (global enable): 'true' enables all primitives. + - Example 2 (global disable): 'false' disables all primitives. + - Example 3 (specific settings): 'DBiasQuantizePrimitive=false,GemmPrimitive=true' disables DBiasQuantizePrimitive and enables GemmPrimitive, leaving others at their default state. + Note that the default state is set at class level based on _default_disable_names. + 2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern to match primitive names. + - Example: 'DBiasQuantizePrimitive' or '^(?!DBiasQuantizePrimitive$).+$' to enable/disable DBiasQuantizePrimitive. + - A deprecation warning is raised if used; it will be removed in future releases. + + Behavior: + 1. Checks if `NVTE_JAX_CUSTOM_CALLS` is set and parses key/value pairs or single true/false value. + 2. If not set, checks `NVTE_JAX_CUSTOM_CALLS_RE` (with deprecation warning) for regex matching. + 3. If neither is set, falls back to the internal state `_is_enabled`. 
""" - pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*") - pattern = re.compile(pattern) - is_enabled = pattern.fullmatch(cls.__name__) is not None - return is_enabled + + # Check new key/value environment variable first + custom_calls_str = os.getenv("NVTE_JAX_CUSTOM_CALLS") + if custom_calls_str is not None: + custom_calls_str = custom_calls_str.strip() + if custom_calls_str.lower() == "true": + return True + if custom_calls_str.lower() == "false": + return False + + # Parse key=value pairs + settings = {} + for pair in custom_calls_str.split(","): + pair = pair.strip() + if "=" in pair: + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip().lower() + settings[key] = value == "true" + if cls.__name__ in settings: + return settings[cls.__name__] + + # Check old regex environment variable (deprecated) + pattern_str = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE") + if pattern_str is not None: + warnings.warn( + "NVTE_JAX_CUSTOM_CALLS_RE is deprecated and will be removed in future releases. Use" + " NVTE_JAX_CUSTOM_CALLS with key=value format instead (e.g.," + " 'DBiasQuantizePrimitive=false').", + DeprecationWarning, + ) + pattern = re.compile(pattern_str) + env_enabled = pattern.fullmatch(cls.__name__) is not None + return env_enabled + + # If no environment variable is set, fall back to the internal state + return cls._is_enabled + + @classmethod + def set_enabled(cls, enabled: bool): + """ + Sets the enabled state for this primitive. + """ + cls._is_enabled = enabled @staticmethod @abstractmethod @@ -78,6 +137,13 @@ def impl(): """ return NotImplemented + @classmethod + def outer_impl(cls, *args, **kwargs): + """ + to describe implementation for outer primitive + """ + return cls.impl(*args, **kwargs) + @staticmethod @abstractmethod def batcher(): @@ -112,10 +178,19 @@ def shardy_sharding_rule(*args): return "... -> ..." +# Registry to store all registered primitive classes +_primitive_registry = {} + + def register_primitive(cls): """ - register jax primitive + Register a JAX primitive and add it to the internal registry. """ + _primitive_registry[cls.__name__] = cls + + # Set default disabled state at class level based on _default_disable_names + if cls.__name__ in BasePrimitive._default_disable_names: + cls.set_enabled(False) def name_of_wrapper_p(): return cls.name + "_wrapper" @@ -131,7 +206,7 @@ def name_of_wrapper_p(): outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results - outer_p.def_impl(cls.impl) + outer_p.def_impl(cls.outer_impl) outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) @@ -153,3 +228,48 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA") + + +def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): + """ + Helper function to manage primitive states by name without modifying environment variables. + Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. + This helper is used in the get_quantize_config().initialize() methods. + + Args: + enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. 
+ disable_names: List of strings, each representing the name of a primitive class to disable. Defaults to None. + disable_all_first: Boolean, if True, disables all primitives before applying enable/disable lists. Defaults to False. + + Note: + 1. If `disable_all_first` is True, all primitives are disabled first, then `enable_names` is applied. + 2. Conflicts (a primitive in both enable and disable lists) are resolved by applying disable last. + """ + + enable_set = set(enable_names or []) + disable_set = set(disable_names or []) + + if disable_all_first: + for name, cls in _primitive_registry.items(): + if ( + isinstance(cls, type) + and issubclass(cls, BasePrimitive) + and cls is not BasePrimitive + ): + cls.set_enabled(False) + + # Apply enables + for name in enable_set: + cls = _primitive_registry.get(name) + if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive): + cls.set_enabled(True) + else: + raise ValueError(f"Primitive not found in registry: {name}") + + # Apply disables (overrides enables if there's a conflict) + for name in disable_set: + cls = _primitive_registry.get(name) + if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive): + cls.set_enabled(False) + else: + raise ValueError(f"Primitive not found in registry: {name}") diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 402ddb8fb..4ba581c66 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -10,6 +10,7 @@ from collections.abc import Iterable from typing import Tuple, Sequence, Union from functools import partial, reduce +import warnings import jax import jax.numpy as jnp @@ -26,19 +27,22 @@ from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type from ..quantize import ( + AbstractBaseTensor, + NoScaleTensor, ScaledTensor, ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, GroupedQuantizer, - QuantizeConfig, + get_quantize_config, QuantizerSet, QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, ) +from ..sharding import global_mesh_resource from .misc import get_padded_spec @@ -161,6 +165,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +@partial(jax.jit, static_argnums=(1, 2)) +def swizzled_scale(scale_inv, flatten_axis, is_colwise): + "Swizzle scale_inv via JAX transpose ops" + original_shape = scale_inv.shape + shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:])) + if is_colwise: + scale_inv = jnp.transpose(scale_inv.reshape(shape_2d)) + cols, rows = shape_2d + else: + rows, cols = shape_2d + reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4) + swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4)) + return swizzled.reshape(original_shape) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -168,7 +187,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (6, 7, 8, 9, 10, 11, 12) inner_primitive = None outer_primitive = None @@ -182,16 +201,13 @@ def abstract( gelu_input, out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): - del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator + del use_split_accumulator def 
_dims_are_consecutive(dims): if len(dims) <= 1: @@ -214,27 +230,6 @@ def _dims_are_consecutive(dims): f"{rhs_contracting_dims}." ) - ( - lhs_batch_dims, - rhs_batch_dims, - ) = map(sanitize_dims, operand_ndims, batched_dims) - assert _dims_are_consecutive(lhs_batch_dims), ( - "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got " - f"{lhs_batch_dims}." - ) - assert _dims_are_consecutive(rhs_batch_dims), ( - "cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got " - f"{rhs_batch_dims}." - ) - if len(lhs_batch_dims) == 0: - assert ( - len(rhs_batch_dims) == 0 - ), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched." - elif len(rhs_batch_dims) != 0: - assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all( - bdim in rhs_contracting_dims for bdim in rhs_batch_dims - ), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched." - lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), (lhs.shape, rhs.shape), @@ -263,6 +258,11 @@ def _dims_are_consecutive(dims): "require non-transposed LHS and transposed RHS operands " "(`contracting_dims=((-1, ), (-1, ))`)." ) + else: + assert lhs.dtype == rhs.dtype, ( + "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." + f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" + ) # Determine output shape and dtype assert ( @@ -314,28 +314,18 @@ def _dims_are_consecutive(dims): ) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) - # Need extra workspace for swizzled scale factors - lhs_swizzle_size = 0 - rhs_swizzle_size = 0 - swizzle_dtype = jnp.uint8 - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_swizzle_size = lhs_scale_inv.size - rhs_swizzle_size = rhs_scale_inv.size - lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) - rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) - # Declare cuBLAS workspace # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. 
workspace_size = get_cublas_workspace_size_bytes() + 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + return output, bias_grad, pre_gelu_out, workspace @staticmethod def outer_abstract(*args, **kwargs): outputs = GemmPrimitive.abstract(*args, **kwargs) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array @staticmethod def lowering( @@ -348,16 +338,14 @@ def lowering( gelu_input, out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): - del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype + del out_dtype + lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -398,33 +386,28 @@ def impl( gelu_input, out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): - lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) - lhs_transposed, rhs_transposed = _get_gemm_layout( - (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) - ) - lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, - scaling_mode, - lhs.shape, - is_colwise=lhs_quantized_colwise, - flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - ) - rhs_scale_inv = apply_padding_to_scale_inv( - rhs_scale_inv, - scaling_mode, - rhs.shape, - is_colwise=rhs_quantized_colwise, - flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, - ) + if scaling_mode.is_1d_block_scaling(): + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims) + rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1 + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) + lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) + rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) outputs = GemmPrimitive.inner_primitive.bind( lhs, @@ -435,26 +418,52 @@ def impl( gelu_input, out_dtype=out_dtype, contracting_dims=contracting_dims, - batched_dims=batched_dims, - lhs_quantized_colwise=lhs_quantized_colwise, - rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, ) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array + + @staticmethod + def outer_impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + return GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ) @staticmethod def batcher( 
         batched_args,
-        jax_batch_dims,
+        batch_dims,
         out_dtype,
         contracting_dims,
-        batched_dims,
-        lhs_quantized_colwise,
-        rhs_quantized_colwise,
         scaling_mode,
         fuse_bias,
         fuse_gelu,
@@ -462,24 +471,13 @@ def batcher(
         use_split_accumulator,
     ):
         assert GemmPrimitive.outer_primitive is not None
-        lhs, _, rhs, *_ = batched_args
-        lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
-        arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
-        arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
-        assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
-            "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
-            f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
-        )
-        arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
-        assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
-            "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
-            f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
-        )
+        lhs_bdims, _, rhs_bdims, *_ = batch_dims

-        # Output is batched like the non-contracting batch dimensions of the LHS operand
-        lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims)
-        lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims)
-        out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims
+        # Batched GEMM is not supported
+        assert (
+            lhs_bdims is None and rhs_bdims is None
+        ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})"
+        out_bdims = (None,)

         # Bias gradient is never batched
         bias_bdims = (None,)
@@ -494,9 +492,6 @@ def batcher(
             *batched_args,
             out_dtype=out_dtype,
             contracting_dims=contracting_dims,
-            batched_dims=batched_dims,
-            lhs_quantized_colwise=lhs_quantized_colwise,
-            rhs_quantized_colwise=rhs_quantized_colwise,
             scaling_mode=scaling_mode,
             fuse_bias=fuse_bias,
             fuse_gelu=fuse_gelu,
@@ -507,168 +502,99 @@ def batcher(
         )

     @staticmethod
-    def _decompose_operand_specs(specs, contracting_dims, batch_dims):
-        ndims = len(specs)
-        cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims))
-
-        # Batch specs
-        bspecs = tuple(specs[i] for i in bdims)
-
-        # Non-batch leading dimension specs
-        lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)
-
-        # Non-batch contracting dimension specs
-        cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)
+    def _parse_operand_output_specs(
+        arg_infos,
+        contracting_dims,
+    ):
+        lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)

-        return bspecs, lspecs, cspecs
+        gsr = global_mesh_resource()
+
+        # Warn if tensor sequence parallelism is configured via tp_resource instead of tpsp_resource
+        if gsr.tp_resource is not None:
+            for i in range(len(lhs_specs) - 1):
+                if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource:
+                    warnings.warn(
+                        "Tensor sequence parallelism detected:"
+                        f" tp_resource='{gsr.tp_resource}' appears twice consecutively in"
+                        f" lhs_specs: {lhs_specs}. Please set MeshResource.tpsp_resource for"
+                        " tensor sequence parallelism to avoid potential issues."
+ ) - @staticmethod - def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): - lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) - lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( - sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims - ) - ( - (lhs_bspecs, lhs_lspecs, lhs_cspecs), - (rhs_bspecs, rhs_lspecs, rhs_cspecs), - ) = map( - GemmPrimitive._decompose_operand_specs, - (lhs_specs, rhs_specs), + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) + lhs_non_cdims, rhs_non_cdims = map( + lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims), + (lhs_ndim, rhs_ndim), (lhs_cdims, rhs_cdims), - (lhs_bdims, rhs_bdims), + ) + lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map( + lambda specs, dims: tuple(specs[i] for i in dims), + (lhs_specs, lhs_specs, rhs_specs, rhs_specs), + (lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims), ) - # Batched dimensions must have the same sharding - if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: - assert all( - lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) - ), ( - "cuBLAS GEMM operand batch dimensions must have the same sharding: " - f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." + reduce_spec = None + for l in lhs_cspecs: + for r in rhs_cspecs: + if l is not None and l == r: + assert reduce_spec is None, "Multiple reduce dimension is detected!" + reduce_spec = l + + if reduce_spec is not None: + # Other non-reduce cdims (if exists) need to be unsharded + lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) + rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) + + # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden + # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim. + # In `rhs_specs`, the batch dim appears only in Wgrad GEMM under `rhs_cspecs`. + rhs_non_cspecs = tuple( + None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs ) - # Only one each of the non-batched leading dimensions and non-batched contracting - # dimensions can be sharded - lhs_ldims, rhs_ldims = map( - lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), - (lhs_ndim, rhs_ndim), - (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), - ) - (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( - lambda specs: tuple(spec for spec in specs if spec is not None), - (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), - ) - assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( - "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " - f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." - ) - assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( - "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " - f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." 
- ) + else: + # Otherwise, require contracting dims of both operands to be unsharded + lhs_cspecs = (None,) * len(lhs_cspecs) + rhs_cspecs = (None,) * len(rhs_cspecs) + + # Non-contracting dims of RHS always needs to be gathered along the FSDP axis + rhs_non_cspecs = tuple( + None if spec is not None and spec == gsr.fsdp_resource else spec + for spec in rhs_non_cspecs + ) - # Extract single leading and contracting dimension specs - (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( - lambda specs: None if len(specs) == 0 else specs[0], - (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), - ) + # Non-contracting dims of LHS to be gathered along the SP axis. + # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for + # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. + lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) + + out_specs = lhs_non_cspecs + rhs_non_cspecs - # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts - # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. - # 1. K1 == K2 != None and N == None - # LHS: (B, M, K) - # RHS: (B, None, K) - # OUT: (B, M, None) --(AR)-> (B, M, None) - # 2. K1 == K2 != None and M == N != None - # LHS: (B, M, K) - # RHS: (B, N, K)--(AG)->(B, None, K) - # OUT: (B, M, None) --(RS)--> (B, M, N) - # 3. M == N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, M, K)--(AG)->(B, None, None) - # OUT: (B, M, None) - # 4. M != N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, N, K)--(AG)->(B, N, None) - # OUT: (B, M, N) - reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec - all_reduce_output = reduce_flag and rhs_lspec is None - reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec - all_reduce_spec = reduce_scatter_spec = scatter_dim = None - - lhs_non_contracting_specs, rhs_non_contracting_specs = map( - lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), - (lhs_specs, rhs_specs), + # specs = merge(cspecs, non_cspecs) + lhs_specs, rhs_specs = map( + lambda cdims, cspecs, non_cspecs: ( + cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs + ), (lhs_cdims, rhs_cdims), + (lhs_cspecs, rhs_cspecs), + (lhs_non_cspecs, rhs_non_cspecs), ) - out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) - if reduce_scatter_output: - # All-gather (if necessary) the non-batch non-contracting dimension of RHS - # (B, N, K) --(AG)-> (B, None, K) - # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) - rhs_spec = tuple( - rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) - ) - reduce_scatter_spec = lhs_cspec - scatter_dim = out_specs.index(rhs_lspec) - - elif all_reduce_output: - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) - all_reduce_spec = lhs_cspec - else: - # All-gather (if necessary) the non-batch contracting dimensions - # (B, M, K) --(AG)-> (B, M, None) - # (B, N, K) --(AG)-> (B, N, None) - # (B, M, None) x (B, N, None)^T = (B, M, N) - lhs_specs = tuple( - None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] - for i in range(lhs_ndim) - ) - rhs_specs = tuple( - None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Check if RHS non-contracting spec also appears in 
the LHS non-contracting specs - if rhs_lspec is not None and rhs_lspec in tuple( - lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims - ): - # All-gather (if necessary) the non-batch non-contracting dimensions of RHS - # (B, N, None) --(AG)-> (B, None, None) - # (B, M, None) x (B, None, None)^T = (B, M, None) - rhs_specs = tuple( - None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) - # Bias and Pre-GeLU sharding is based on GEMM output - bias_specs = out_specs[len(lhs_non_contracting_specs) :] - gelu_specs = out_specs + # Bias and Pre-GeLU sharding is based on GEMM output before any scatter + bias_specs = tuple(list(rhs_non_cspecs).copy()) + gelu_specs = tuple(list(out_specs).copy()) return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), - all_reduce_spec, - reduce_scatter_spec, - scatter_dim, + reduce_spec, ) @staticmethod def infer_sharding_from_operands( out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -680,15 +606,13 @@ def infer_sharding_from_operands( ): del ( out_dtype, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, grad, ) del use_split_accumulator, result_infos - (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + (_, (out_specs, dbias_specs, pre_gelu_specs), _) = ( + GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -708,9 +632,6 @@ def infer_sharding_from_operands( def partition( out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -725,10 +646,8 @@ def partition( ( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), - all_reduce_spec, - reduce_scatter_spec, - scatter_dim, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + reduce_spec, + ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) # Assemble argument shardings # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. 
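The simplified partitioning rule above boils down to: if both operands shard the same contracting axis, each device runs a partial GEMM and the result is all-reduced (the `jax.lax.psum` added to `_sharded_impl` in the next hunk). A minimal sketch of that collective pattern, assuming a one-axis mesh and illustrative shapes:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("tp",))  # one-axis mesh; the name is illustrative

def _partial_gemm(lhs_local, rhs_local):
    # Each device holds a K-shard of both operands, so the local matmul is a
    # partial sum over K; psum over the shared contracting axis completes it.
    return jax.lax.psum(lhs_local @ rhs_local.T, "tp")

gemm_allreduce = shard_map(
    _partial_gemm,
    mesh=mesh,
    in_specs=(P(None, "tp"), P(None, "tp")),  # LHS (M, K) and RHS (N, K), K sharded
    out_specs=P(None, None),                  # replicated (M, N) output
)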
@@ -775,9 +694,6 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): gelu_input, out_dtype=out_dtype, contracting_dims=contracting_dims, - batched_dims=batched_dims, - lhs_quantized_colwise=lhs_quantized_colwise, - rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -785,19 +701,9 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): use_split_accumulator=use_split_accumulator, ) - # All-Reduce/Reduce-Scatter GEMM output - if all_reduce_spec is not None: - outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) - if fuse_gelu and not grad: - outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) - elif reduce_scatter_spec is not None: - outputs[0] = jax.lax.psum_scatter( - outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True - ) - if fuse_gelu and not grad: - outputs[2] = jax.lax.psum_scatter( - outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True - ) + # All-Reduce GEMM output + if reduce_spec is not None: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) return outputs @@ -807,9 +713,6 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): def shardy_sharding_rule( out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -819,40 +722,39 @@ def shardy_sharding_rule( operand_types, result_types, ): - del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator + del out_dtype, grad, use_split_accumulator del mesh, result_types prefix = "GemmPrimitive_" - def _generate_operand_rules(name, ndim, cdims, bdims): + warnings.warn( + "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," + " please turn off Shardy by exporting the environment variable" + " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems." 
+ ) + + def _generate_operand_rules(name, ndim, cdims): specs = [] - ldims = tuple(i for i in range(ndim) if i not in bdims + cdims) + ldims = tuple(i for i in range(ndim) if i not in cdims) for i in range(ndim): dim_name = None - if i in bdims: - dim_idx = bdims.index(i) if len(bdims) > 1 else "" - dim_name = f"b{dim_idx}" - elif i in cdims: - dim_idx = cdims.index(i) if len(cdims) > 1 else "" + if i in cdims: + dim_idx = cdims.index(i) dim_name = f"k{dim_idx}" else: - dim_idx = ldims.index(i) if len(ldims) > 1 else "" + dim_idx = ldims.index(i) dim_name = f"{name}_l{dim_idx}" specs.append(prefix + dim_name) return specs lhs, _, rhs, *_ = operand_types operand_ndims = (len(lhs.shape), len(rhs.shape)) - (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( - lambda dims: map(sanitize_dims, operand_ndims, dims), - (contracting_dims, batched_dims), - ) + (lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims) lhs_specs, rhs_specs = map( _generate_operand_rules, ("lhs", "rhs"), operand_ndims, (lhs_cdims, rhs_cdims), - (lhs_bdims, rhs_bdims), ) lhs_scale_specs = ("…1",) rhs_scale_specs = ("…2",) @@ -904,11 +806,10 @@ def _te_gemm( lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, - use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, + use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -919,7 +820,6 @@ def _te_gemm( scaling_mode = ScalingMode.NO_SCALING lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) - lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -938,7 +838,6 @@ def _te_gemm( lhs_scale_inv = lhs_q.scale_inv if lhs_q.data_layout == "T": lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) - lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -956,7 +855,6 @@ def _te_gemm( rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) - rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -974,9 +872,6 @@ def _te_gemm( gelu_input, out_dtype=out_dtype, contracting_dims=(lhs_cdims, rhs_cdims), - batched_dims=(lhs_bdims, rhs_bdims), - lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, - rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -1184,10 +1079,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) - lhs_batch = 
transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis) if rhs.data_layout == "T": rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) - rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -1269,7 +1162,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" precision = ( jax.lax.Precision.HIGHEST - if QuantizeConfig.FP8_2X_ACC_FPROP + if get_quantize_config().FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT ) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) @@ -1277,7 +1170,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) - raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") + raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}") lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -1296,10 +1189,9 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( - lhs: Union[jnp.ndarray, ScaledTensor], - rhs: Union[jnp.ndarray, ScaledTensor], + lhs: Union[jnp.ndarray, AbstractBaseTensor], + rhs: Union[jnp.ndarray, AbstractBaseTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, **kwargs, @@ -1318,11 +1210,6 @@ def gemm( Object for down-casting the RHS operand for quantized GEMM. contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) Tuple of sequences representing the contracting dimensions of the operands. - batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), - Tuple of sequences representing the batched dimensions of the operands. This is *not* used - to perform a batched matrix multiplication, but it is required to avoid a potentially - undesirable reduction in any batched contracting dimensions when invoked with sharded - operands (e.g. when computing weight gradients in a Flax module). bias: jax.Array, default = None Optional additive bias term, required for forward GEMM with bias fusion. Only supported with TE's custom call to cuBLAS GEMM. @@ -1340,7 +1227,8 @@ def gemm( TE's custom call to cuBLAS GEMM. use_split_accumulator: bool, default = True Enable promoting some intermediate sums to higher precision when accumulating the result in - the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only + supported with TE's custom call to cuBLAS GEMM. Returns ------- @@ -1358,6 +1246,11 @@ def gemm( compute the GeLU contribution to the gradient. Only supported with TE's custom call to cuBLAS GEMM. 
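A minimal call sketch for the updated signature (shapes and the import alias are illustrative); with `batched_dims` removed, leading dimensions of the LHS are simply treated as non-contracting dimensions:

import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex

x = jnp.ones((4, 128, 256), dtype=jnp.bfloat16)  # (batch, seq, hidden)
w = jnp.ones((256, 512), dtype=jnp.bfloat16)     # (hidden, ffn)

# Contract the hidden dimension of x against the first dimension of w.
# Without quantizers this runs the unquantized cuBLAS path, or falls back to
# jax.lax.dot_general when the GemmPrimitive custom call is disabled.
out = tex.gemm(x, w, contracting_dims=((2,), (0,)))  # GEMM output of shape (4, 128, 512)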
""" + if isinstance(lhs, NoScaleTensor): + lhs = lhs.data + if isinstance(rhs, NoScaleTensor): + rhs = rhs.data + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility if lhs_quantizer is None or rhs_quantizer is None: quantizer_set = kwargs.get("quantizer_set", None) @@ -1371,12 +1264,12 @@ def gemm( if not GemmPrimitive.enabled(): assert kwargs.get("bias", None) is None and not fuse_gelu, ( "TE GEMM was invoked with bias fusion options that are not supported by the " - "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( "TE GEMM was invoked with GeLU fusion options that are not supported by the " - "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) @@ -1387,7 +1280,6 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, - batched_dims=batched_dims, **kwargs, ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 0db90e384..e7464a6da 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -199,6 +199,16 @@ def get_min_device_compute_capability(): ) +def get_all_device_compute_capability(): + """ + Returns a list of compute capability of all local devices. + """ + return tuple( + transformer_engine_jax.get_device_compute_capability(local_gpu_id) + for local_gpu_id in range(len(jax.local_devices())) + ) + + def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None): """ Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 8885ae2ea..89731e24a 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -34,7 +34,7 @@ ) from .quantization import _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -848,9 +848,12 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) output = normed_input * gamma + beta if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1) @@ -872,9 +875,12 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): output = normed_input * gamma if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = 
NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(rsigma, axis=-1) @@ -936,7 +942,7 @@ def layernorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, mu, rsigma + return NoScaleTensor(data=output, amax=None), mu, rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING @@ -1070,7 +1076,7 @@ def layernorm_bwd( ) mu_empty = jnp.zeros(mu.shape, mu.dtype) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, mu_empty, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, amax=None), mu_empty, rsigma_empty)) return NormBwdPrimitive.outer_primitive.bind( dz, x, @@ -1139,14 +1145,14 @@ def rmsnorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, rsigma + return NoScaleTensor(data=output, amax=None), rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) - out, _ = _quantize_dbias_impl(out, quantizer) + out, _ = _quantize_dbias_impl(out.data, quantizer) return out, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1158,7 +1164,9 @@ def rmsnorm_fwd( epsilon=epsilon, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + ) return out, rsigma is_2x2x = quantizer.is_2x2x() @@ -1260,7 +1268,7 @@ def rmsnorm_bwd( gamma, ) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, amax=None), rsigma_empty)) mu = jnp.empty(()) dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind( dz, @@ -1282,7 +1290,6 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], - noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. @@ -1299,7 +1306,6 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: A tuple containing: @@ -1327,15 +1333,6 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") - if quantizer is None and noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype - ), - mu, - rsigma, - ) - return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7a5b31ad7..78780ff9c 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,7 +6,7 @@ """JAX/TE custom ops for quantization""" import operator from functools import reduce -from typing import Tuple, Optional +from typing import Tuple, Optional, Union import math from packaging import version @@ -41,6 +41,7 @@ QuantizeLayout, ScalingMode, compute_scale_from_amax, + NoScaleTensor, ) if version.parse(jax.__version__) >= version.parse("0.5.0"): @@ -60,13 +61,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): name = "te_dbias_quantize_ffi" multiple_results = True impl_static_args = ( - 2, 3, 4, 5, 6, 7, 8, + 9, ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer inner_primitive = None outer_primitive = None @@ -75,6 +76,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, *, out_dtype, scaling_mode, @@ -98,7 +100,7 @@ def abstract( rowwise_out_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + updated_amax_aval = amax_aval rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode @@ -171,6 +173,7 @@ def lowering( ctx, x, scale, + amax, *, out_dtype, scaling_mode, @@ -184,13 +187,17 @@ def lowering( te_dbias_quantize_p lowering rules """ del out_dtype, scale_dtype, is_outer - x_aval, scale_aval = ctx.avals_in + x_aval, scale_aval, amax_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)( + assert scale_aval.dtype == amax_aval.dtype == jnp.float32 + return ffi.ffi_lowering( + BaseDBiasQuantizePrimitive.name, + operand_output_aliases={2: 4}, # donate amax buffer to updated_amax + )( ctx, x, scale, + amax, scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, @@ -201,6 +208,7 @@ def lowering( def impl( x, scale, + amax, out_dtype, scaling_mode, q_layout, @@ -225,6 +233,7 @@ def impl( ) = BaseDBiasQuantizePrimitive.inner_primitive.bind( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -271,15 +280,15 @@ def batcher( del is_outer check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale = batched_args - x_bdim, scale_bdim = batch_dims - amax_bdim = scale_bdim + x, scale, amax = batched_args + x_bdim, scale_bdim, amax_bdim = batch_dims out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim return ( BaseDBiasQuantizePrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -306,7 +315,7 @@ def infer_sharding_from_operands( del (out_dtype, result_infos, scale_dtype, is_outer) # Unused. 
x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), @@ -332,10 +341,8 @@ def infer_sharding_from_operands( desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -344,14 +351,14 @@ def infer_sharding_from_operands( scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" ) - amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" - ) colwise_scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*colwise_scale_inv_spec), desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) + amax_sharding = NamedSharding( + mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" + ) return ( out_sharding, @@ -378,7 +385,7 @@ def partition( del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), @@ -404,10 +411,8 @@ def partition( desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -435,7 +440,7 @@ def partition( dbias_sharding, ) - def sharded_impl(x, scale): + def sharded_impl(x, scale, amax): ( local_x, local_colwise_x, @@ -446,6 +451,7 @@ def sharded_impl(x, scale): ) = BaseDBiasQuantizePrimitive.impl( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -515,7 +521,7 @@ def shardy_sharding_rule( amax = (prefix + "amax",) return SdyShardingRule( - (x_axes, ("…1",)), + (x_axes, ("…1",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), ) @@ -524,22 +530,26 @@ def shardy_sharding_rule( class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive): - """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" class QuantizePrimitive(BaseDBiasQuantizePrimitive): - """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) -def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): +def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1): + if isinstance(dx, NoScaleTensor): + dx = dx.data sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis assert sum_axis < dx.ndim, "Flatten axis out of bounds!" dtype = dtype or dx.dtype @@ -558,7 +568,9 @@ def _jax_quantize_dbias( flatten_axis: int = -1, ): if quantizer is None: - return x, None + if isinstance(x, NoScaleTensor): + return x, None + return NoScaleTensor(data=x, amax=None), None return ( quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis), _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis), @@ -566,12 +578,11 @@ def _jax_quantize_dbias( def _quantize_dbias_impl( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -581,28 +592,15 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + # Early-exit for non-quantized call - dq_dtype = dq_dtype or x.dtype + dq_dtype = dq_dtype or x.data.dtype if quantizer is None: dbias = None if is_dbias: - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - if noop_scaled_tensor: - # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() - # always works. - return ( - ScaledTensorFactory.create_2x( - x, - None, - x, - None, - ScalingMode.NO_SCALING, - dq_dtype=x.dtype, - data_layout="NN", - flatten_axis=flatten_axis, - ), - dbias, - ) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return x, dbias # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, @@ -630,19 +628,26 @@ def _quantize_dbias_impl( dq_dtype=dq_dtype, flatten_axis=flatten_axis, ) - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias scale = jnp.empty((), jnp.float32) + amax = None if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). 
- amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) + amax = x.amax + if amax is None: + amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale + # Make sure amax is init with zero + if amax is None: + amax = jnp.zeros((1,), jnp.float32) + # It is faster to use 1x quantization for tensor scaling is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( @@ -662,8 +667,9 @@ def _quantize_dbias_impl( updated_amax, dbias, ) = PrimitiveClass.outer_primitive.bind( - x, + x.data, scale, + amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout.value, @@ -702,10 +708,9 @@ def _quantize_dbias_impl( def quantize( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -715,7 +720,6 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. - noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer is None. Returns: @@ -725,17 +729,15 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) return out def quantize_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -746,8 +748,6 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. - noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when - quantizer is None. Returns: A tuple containing: @@ -761,7 +761,6 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) @@ -936,6 +935,7 @@ def grouped_quantize( x: jnp.ndarray, quantizer: GroupedQuantizer, group_sizes: jnp.ndarray = None, + amax: jnp.ndarray = None, flatten_axis: int = -1, ) -> GroupedScaledTensor1x: """Quantize a tensor in grouped manner. @@ -948,6 +948,7 @@ def grouped_quantize( x: Input tensor to quantize quantizer: The quantizer to use for quantization group_sizes: Array of ints containing the size of each group (default: None) + amax: The amax of x; if None, it is auto-generated. 
(default: None) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) Returns: @@ -962,7 +963,9 @@ def grouped_quantize( """ if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) # TODO(Phuong): add support for flatten_axis = -2 assert flatten_axis in ( @@ -990,7 +993,10 @@ def grouped_quantize( scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + if amax is not None: + row_amax = amax + else: + row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) segment_ids = jnp.repeat( jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] ) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index cf75c850b..17fa9906b 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -37,9 +37,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis - auto input_shape = std::vector{m, act_len * n}; - auto output_shape = std::vector{m, n}; - auto output_trans_shape = std::vector{n, m}; + auto input_shape = std::vector{m, static_cast(act_len * n)}; + auto output_shape = std::vector{m, static_cast(n)}; + auto output_trans_shape = std::vector{static_cast(n), m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); @@ -253,11 +253,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto m = product(act_input_dims, 0, act_input_dims.size() - 2); auto n = input_dims.back(); - auto input_shape = std::vector{m, n}; - auto act_input_shape = std::vector{m, n * act_len}; - auto output_shape = std::vector{m, n * act_len}; - auto output_trans_shape = std::vector{n * act_len, m}; - auto dbias_shape = std::vector{n * act_len}; + auto input_shape = std::vector{m, static_cast(n)}; + auto act_input_shape = std::vector{m, static_cast(n * act_len)}; + auto output_shape = std::vector{m, static_cast(n * act_len)}; + auto output_trans_shape = std::vector{static_cast(n * act_len), m}; + auto dbias_shape = std::vector{static_cast(n * act_len)}; std::vector workspace_shape(workspace_dims.begin(), workspace_dims.end()); auto input_tensor = diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 3e0842ef5..7015c2f5e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -38,8 +38,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { } std::tuple> xla_buffer_to_nvte_gemm_operand( - cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, - JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode, + size_t axis_boundary, bool rowwise) { // Set tensor data with collapsed 2D shape auto buffer_dims = buffer.dimensions(); std::vector input_shape = {product(buffer_dims, 0, axis_boundary), @@ -71,40 +71,6 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } - - // Swizzle scaling factors for MXFP8 - if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Get the swizzle buffer - NVTE_CHECK(swizzled_scale_inv->element_count() > 0, - "Missing swizzled inverse scale buffer in the JAX primitive."); - auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - auto swizzled_scale_inv_dtype = - convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); - NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, - "Inverse scale factors need to have an 8-bit data type."); - - // Create tensor to hold swizzled scale factor - TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); - if (rowwise) { - output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - - // Launch swizzle kernel - nvte_swizzle_scaling_factors(input.data(), output.data(), stream); - - // Set swizzled scales into the input tensor - if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - } } return std::make_tuple(std::move(input), input_shape); @@ -113,21 +79,19 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { - // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? 
true : rhs_transposed; - auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); - auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode, + lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, + rhs_axis_boundary, make_rhs_rowwise); // Output tensor std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], @@ -198,8 +162,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out - .Ret() // lhs_swizzled - .Ret() // rhs_swizzled .Ret() // workspace .Attr("scaling_mode") .Attr("lhs_axis_boundary") @@ -295,18 +257,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t out_dtype_bytes = te_dtype_bytes(out_dtype); if (is_tensor_scaling) { - cudaStream_t stream_0 = nvte_get_compute_stream(0); size_t dpitch = tensor_scaling_sinv_aligment; size_t spitch = lhs_sinv_dtype_bytes; size_t width = lhs_sinv_dtype_bytes; size_t height = lhs_sinv_size; cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream_0); + cudaMemcpyDeviceToDevice, stream); spitch = rhs_sinv_dtype_bytes; width = rhs_sinv_dtype_bytes; height = rhs_sinv_size; cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream_0); + cudaMemcpyDeviceToDevice, stream); lhs_sinv_ptr = lhs_scatter_aligned_ptr; rhs_sinv_ptr = rhs_scatter_aligned_ptr; } @@ -575,10 +536,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); } - nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, - lhs_is_trans, grad, workspace_list.data(), accumulate, - use_split_accumulator, num_math_sm, stream); + nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + grad, workspace_list.data(), accumulate, use_split_accumulator, + num_math_sm, stream); return ffi_with_cuda_error_check(); } @@ -602,8 +563,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad"), - GemmFFI_CudaGraph_Traits); + .Attr("is_grouped_dense_wgrad")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index b07404eb7..c35bc6668 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -118,7 +118,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{ product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); + static_cast(scale_inv_buf->dimensions().back())}); } if (scaling_mode == 
JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { @@ -135,7 +135,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), std::vector{product(colwise_scale_inv_buf->dimensions(), 0, colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + static_cast(colwise_scale_inv_buf->dimensions().back())}); } if (_norm_type == NVTE_Norm_Type::LayerNorm) { diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index a92934193..d17d83ec1 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -72,9 +72,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type output_trans_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type output_trans_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); @@ -119,11 +120,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (is_fp8_dtype(out_dtype)) { if (is_tensor_scaling) { float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), @@ -183,6 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv @@ -410,8 +411,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // amax .Attr("scaling_mode") .Attr("q_layout") - .Attr("flatten_axis"), - FFI_CudaGraph_Traits); + .Attr("flatten_axis")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index a0fc7b7af..8087159a3 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -8,7 +8,7 @@ It implements matrix multiplication with optional bias addition and supports customizable contracting dimensions for flexible tensor operations. """ -import warnings + from typing import Tuple, Sequence from functools import partial import jax @@ -16,21 +16,44 @@ from . 
import cpp_extensions as tex from .quantize import ( + ScaledTensorFactory, + ScalingMode, + QuantizeLayout, QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, + is_fp8_gemm_with_all_layouts_supported, TensorUsage, + get_quantize_config, ) -DENSE_BATCH_FIRST_WARNING_ISSUED = False +def _all_gather_kernel(kernel, mesh_axis, axis_idx): + assert mesh_axis is not None + assert 0 < axis_idx < len(kernel.shape) + + # TODO(Ming Hunag): Add a condition branch for with/without shmap. + kernel_shape = kernel.shape + kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :]) + global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx) + global_kernel = global_kernel.reshape(*kernel_whole_shape) + return global_kernel + +def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx): + assert mesh_axis is not None + assert 0 < axis_idx < len(scattered_kernel_shape) -def _issue_batch_first_warning(msg): - global DENSE_BATCH_FIRST_WARNING_ISSUED - if not DENSE_BATCH_FIRST_WARNING_ISSUED: - warnings.warn(msg, UserWarning) - DENSE_BATCH_FIRST_WARNING_ISSUED = True + # TODO(Ming Hunag): Add a condition branch for with/without shmap. + kernel = kernel.reshape( + *scattered_kernel_shape[:axis_idx], + -1, + scattered_kernel_shape[axis_idx], + *scattered_kernel_shape[axis_idx + 1 :], + ) + kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx) + kernel = kernel.reshape(scattered_kernel_shape) + return kernel def dense( @@ -40,7 +63,6 @@ def dense( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, - batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -54,28 +76,44 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract - batch_first: Assume that X is batched in the first dimension. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ - # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): - x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims=contracting_dims) - if bias is not None: - bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape - output += jnp.reshape(bias, bias_new_shape) - else: - output = _dense( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set - ) + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + + output = _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + quantizer_set, + ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): +@partial( + jax.custom_vjp, + nondiff_argnums=( + 3, + 4, + 5, + ), +) +def _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + quantizer_set, +): """Internal implementation of dense layer transformation with custom VJP. 
This function implements the core dense layer transformation logic with support @@ -89,19 +127,30 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types - batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + quantizer_set, ) return output def _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -119,28 +168,13 @@ def _dense_fwd_rule( not x_is_transposed and not k_is_transposed ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." - # Determine X batch dimension - # - If `batch_first=True` -> (batch, leading..., contracting...) - # - Otherwise -> (leading..., batch, contracting...) - # NOTE: Always assume a single batch dimension - x_bdim = None - num_cdims = len(x_contracting_dims) - if x.ndim >= num_cdims + 2: - # Assume X is batched if it has at least +2 dimensions more than the number of contracting - # dimensions. - if not batch_first: - _issue_batch_first_warning( - "TE/JAX `dense()` layer implementation does not officially support sequence-first " - "inputs and may produce incorrect results when `batch_first=False`. Use " - "sequence-first inputs at your own discretion.", - ) - x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 - flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) casted_x = tex.quantize( - x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + x, + flatten_axis=flatten_axis_x, + quantizer=quantizer_set.x, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -148,7 +182,6 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -158,7 +191,6 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -175,13 +207,12 @@ def _dense_fwd_rule( use_bias, quantizer_set, flatten_axis_k, - x_bdim, ) return output, ctx def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad + contracting_dims, input_axes, kernel_axes, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. 
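For reference, a minimal sketch of calling the slimmed-down dense() API after the batch_first removal; the shapes, dtypes, and the 3-D input here are illustrative assumptions, not values taken from this patch. When FP8 is not enabled, dense() now casts the kernel to the input dtype before dispatching to _dense().

    import jax.numpy as jnp
    from transformer_engine.jax.dense import dense

    x = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)    # (batch, seq, hidden_in), assumed shape
    kernel = jnp.zeros((1024, 4096), dtype=jnp.float32)  # (hidden_in, hidden_out)
    bias = jnp.zeros((4096,), dtype=jnp.bfloat16)

    # Contract the trailing dim of x with the leading dim of the kernel ("NN" layout);
    # without an fp8_autocast context the kernel is cast to x.dtype inside dense().
    out = dense(x, kernel, bias=bias, contracting_dims=((2,), (0,)))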
@@ -196,7 +227,6 @@ def _dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, - x_bdim, ) = ctx fwd_x_contracting_dims, fwd_k_contracting_dims = map( @@ -208,7 +238,6 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # GEMM NT @@ -220,11 +249,11 @@ def _dense_bwd_rule( k_contracting_dim = tuple( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) + dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), - batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) @@ -238,7 +267,6 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) @@ -254,10 +282,12 @@ def grouped_dense( group_sizes: jnp.ndarray, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), bias: jnp.ndarray = None, + kernel_amax: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, + kernel_fsdp_info: Tuple[str, int] = (None, -1), ): """ Perform grouped dense (linear) layer transformation with optional quantization. @@ -269,10 +299,15 @@ def grouped_dense( contracting_dims: Tuple of sequences specifying which dimensions to contract (currently only supports ((1,), (1,))) bias: Bias tensor of shape (G, N) + kernel_amax: The amax values of weight matrix of shape (G,) precision: JAX precision for the GEMM operation preferred_element_type: Preferred data type for the output tensor group_offset: 1D array containing offsets for each group (not yet implemented) quantizer_set: Set of quantizers for FP8 quantization of the input and output + kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix + represented in the format (str, int). The first element is the + FSDP mesh axis, and the second element is the dimension along + which the weight is sharded. 
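A hypothetical sketch of the new kernel_fsdp_info argument described above. The mesh axis name "fsdp", the shapes, and the surrounding shard_map context are assumptions for illustration; per the TODO in the new helpers, the current implementation expects to run under shard_map.

    import jax.numpy as jnp
    from transformer_engine.jax.dense import grouped_dense

    def local_grouped_dense(x, kernel_shard, group_sizes):
        # x:            (M, K) local tokens
        # kernel_shard: (G, K, N // fsdp_size) local slice, sharded along dim 2
        # grouped_dense all-gathers the kernel over the "fsdp" mesh axis before the
        # grouped GEMM and reduce-scatters the weight gradient on the way back.
        return grouped_dense(
            x,
            kernel_shard,
            group_sizes,
            contracting_dims=((1,), (1,)),
            kernel_fsdp_info=("fsdp", 2),  # (mesh axis name, sharded kernel dim)
        )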
Returns: A jnp.ndarray containing the result of the grouped linear operation @@ -283,25 +318,29 @@ def grouped_dense( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10)) def _grouped_dense( x, kernel, group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ): output, _ = _grouped_dense_fwd_rule( x, @@ -309,10 +348,12 @@ def _grouped_dense( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ) return output @@ -323,21 +364,31 @@ def _grouped_dense_fwd_rule( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ): use_bias = bias is not None is_noop_quantizer_set = quantizer_set == noop_quantizer_set + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None + if is_noop_quantizer_set: grouped_gemm_x = x grouped_gemm_kernel = kernel ctx_x = x ctx_kernel = kernel flatten_axis_k = None + + if kernel_fsdp_enabled: + kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) else: + original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout + x_contracting_dims, k_contracting_dims = contracting_dims flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -353,10 +404,24 @@ def _grouped_dense_fwd_rule( ) casted_x = tex.grouped_quantize( - x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x + x, + quantizer_set.x, + group_sizes, + flatten_axis=flatten_axis_x, ) + + ctx_kernel_usage = TensorUsage.RHS_TRANS + if kernel_fsdp_enabled: + assert quantizer_set.kernel.scaling_mode in [ + ScalingMode.CURRENT_TENSOR_SCALING, + ScalingMode.DELAYED_TENSOR_SCALING, + ] + # Perform `cast` only + ctx_kernel_usage = TensorUsage.LHS + quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE + casted_kernel = tex.grouped_quantize( - kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k + kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k ) contracting_dims = (x_contracting_dims, k_contracting_dims) @@ -364,9 +429,51 @@ def _grouped_dense_fwd_rule( # rowwise_casted_x.original_shape == (M, K) # colwise_casted_kernel.original_shape == (G, N, K) grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) - grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) - ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) + ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage) + + if kernel_fsdp_enabled: + ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape) + global_ctx_kernel_data = _all_gather_kernel( + ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) + kernel_shape = global_ctx_kernel_data.shape + + ctx_kernel = ScaledTensorFactory.create_1x( + global_ctx_kernel_data.reshape(-1), + ctx_kernel.scale_inv, + scaling_mode=ctx_kernel.scaling_mode, + dq_dtype=ctx_kernel.dq_dtype, + is_colwise=False, + data_layout="N", + flatten_axis=ctx_kernel.flatten_axis, + group_sizes=ctx_kernel.group_sizes, + original_shape=kernel_shape, + 
group_axis=ctx_kernel.group_axis, + ) + + if is_fp8_gemm_with_all_layouts_supported(): + grouped_gemm_kernel = ctx_kernel + else: + grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1) + grouped_gemm_kernel = ScaledTensorFactory.create_1x( + grouped_gemm_kernel_data.reshape(-1), + ctx_kernel.scale_inv, + scaling_mode=ctx_kernel.scaling_mode, + dq_dtype=ctx_kernel.dq_dtype, + is_colwise=True, + data_layout="T", + flatten_axis=ctx_kernel.flatten_axis, + group_sizes=ctx_kernel.group_sizes, + original_shape=kernel_shape, + group_axis=ctx_kernel.group_axis, + ) + else: + grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) + + # Reset quantizer_set.kernel.q_layout to align the PyTree as the given one. + # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. + quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout output = tex.grouped_gemm( grouped_gemm_x, @@ -394,7 +501,7 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( - contracting_dims, precision, preferred_element_type, group_offset, ctx, grad + contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims @@ -475,11 +582,17 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + if kernel_fsdp_mesh_axis is not None: + wgrad = _psum_scatter_kernel( + wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + dkernel_amax = None - return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set + return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 5992d3607..c548c54ef 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -15,12 +15,14 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from ..dense import dense, _issue_batch_first_warning as _dense_warning +from transformer_engine.common import recipe + +from ..dense import dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm -from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning -from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning +from ..layernorm_dense import layernorm_dense +from ..layernorm_mlp import layernorm_mlp from ..activation import activation from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes @@ -30,8 +32,14 @@ jax_scaled_masked_softmax, jax_scaled_upper_triang_masked_softmax, ) -from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode -from ..sharding import get_non_contracting_logical_axes +from ..quantize import ( + QuantizerFactory, + get_quantize_config, + QuantizeMeta, + QuantizeMetaSet, + ScalingMode, + TensorSource, +) PRNGKey = Any Shape = Tuple[int, ...] @@ -274,10 +282,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. 
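Stepping back to the _all_gather_kernel and _psum_scatter_kernel helpers added earlier in this patch, the reshape bookkeeping is the non-obvious part: all_gather inserts a device axis at axis_idx, and the reshape with -1 folds it back into the sharded dimension. A toy, collective-free sketch of the shapes involved (sizes are made up):

    import jax.numpy as jnp

    G, K, N, D = 2, 4, 8, 4                 # toy sizes; D devices on the FSDP mesh axis
    local_shape = (G, K, N // D)            # per-device kernel shard, dim 2 is sharded

    # jax.lax.all_gather(local, "fsdp", axis=2) would insert a device axis of size D here;
    # the helper's reshape folds it back into the sharded dimension.
    gathered = jnp.zeros((G, K, D, N // D))
    full = gathered.reshape(G, K, -1)
    assert full.shape == (G, K, N)

    # The wgrad path does the inverse before psum_scatter: split the gathered dim into
    # (D, N // D) so each device keeps only its own slice of the reduction.
    split = full.reshape(G, K, D, N // D)
    assert split.shape[2] == D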
- transpose_batch_sequence : bool, default = False - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). """ epsilon: float = 1e-6 @@ -288,7 +292,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = ("embed",) dtype: DType = jnp.float32 - transpose_batch_sequence: bool = False def __post_init__(self): self.scale_init = _obtain_default_layernorm_scale_init_if_need( @@ -343,29 +346,38 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method Base class of transformer engine """ - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set( + self, postfix: str = "", variable_collection: str = None, fp8_recipe=None + ): """ Generate a set of FP8 meta for a GEMM. """ def generate_quantize_meta(quantizer_name: str): + collection_name = ( + variable_collection + if variable_collection is not None + else get_quantize_config().COLLECTION_NAME + ) scale = self.variable( - QuantizeConfig.COLLECTION_NAME, + collection_name, f"{quantizer_name}{postfix}_scale", jnp.ones, (1,), jnp.float32, ).value amax_history = self.variable( - QuantizeConfig.COLLECTION_NAME, + collection_name, f"{quantizer_name}{postfix}_amax_history", jnp.zeros, - (QuantizeConfig.AMAX_HISTORY_LEN,), + (get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32, ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: + if get_quantize_config().get_scaling_mode( + TensorSource.X + ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling): x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") @@ -374,7 +386,7 @@ def generate_quantize_meta(quantizer_name: str): else: kwargs = {} - quantizer_set = QuantizerFactory.create_set(**kwargs) + quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs) return quantizer_set @@ -420,10 +432,6 @@ class DenseGeneral(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). """ features: Union[Iterable[int], int] @@ -437,16 +445,9 @@ class DenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - transpose_batch_sequence: bool = False input_axes: Tuple[str, ...] = () def __post_init__(self): - if self.transpose_batch_sequence: - _dense_warning( - "TE/JAX DenseGeneral() module does not officially support sequence-first inputs " - "and may produce incorrect results when `transpose_batch_sequence=True`. Use " - "sequence-first inputs at your own discretion." 
- ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype @@ -489,7 +490,7 @@ def __call__(self, inputs: Array) -> Array: self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel = kernel.astype(input_dtype) if self.use_bias: @@ -628,10 +629,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). depth_scaling: float, default = None The factor to scale the output from `DenseGeneral`. It should be a float value or None. When None is set, then no scaling is applied. @@ -657,18 +654,11 @@ class LayerNormDenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None depth_scaling: float = None def __post_init__(self): - if self.transpose_batch_sequence: - _ln_dense_warning( - "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first " - "inputs and may produce incorrect results when `transpose_batch_sequence=True`. " - "Use sequence-first inputs at your own discretion." - ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, @@ -709,7 +699,7 @@ def __call__(self, inputs: Array) -> Array: quantizer_set = self.generate_quantizer_set() fuse_layernorm = ( - QuantizeConfig.is_fp8_enabled() + get_quantize_config().is_fp8_enabled() and not self.return_layernorm_output and self.enable_layernorm ) @@ -760,7 +750,7 @@ def __call__(self, inputs: Array) -> Array: kernel_shape, self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -936,15 +926,16 @@ class LayerNormMLP(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input of 2nd dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + ffn1_ckpt_name: str = "ffn1" + Checkpoint name for the output of the first fully-connected layer in the MLP block. + ffn2_ckpt_name: str = "ffn2" + Checkpoint name for the output of the second fully-connected layer in the MLP block. + Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). """ intermediate_dim: int = 2048 @@ -973,18 +964,13 @@ class LayerNormMLP(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] 
= None + ffn1_ckpt_name: str = "ffn1" + ffn2_ckpt_name: str = "ffn2" def __post_init__(self): - if self.transpose_batch_sequence: - _ln_mlp_warning( - "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs " - "and may produce incorrect results when `transpose_batch_sequence=True`. Use " - "sequence-first inputs at your own discretion." - ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype @@ -1026,7 +1012,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: # TODO(Phuong): use fuse_layernorm for high-precision # when NoOpQuantizer and Tensor are implemented fuse_layernorm = ( - QuantizeConfig.is_fp8_enabled() + get_quantize_config().is_fp8_enabled() and not self.return_layernorm_output and self.enable_layernorm ) @@ -1109,7 +1095,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] @@ -1121,7 +1107,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): kernel_2_shape, self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -1146,9 +1132,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_1 = None bias_2 = None - ffn1_ckpt_name = "ffn1" - ffn2_ckpt_name = "ffn2" - if use_fused_layernorm_mlp: out = layernorm_mlp( y, @@ -1164,8 +1147,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_2_input_axes=self.dot_2_input_axes, kernel_1_axes=self.kernel_axes_1, kernel_2_axes=self.kernel_axes_2, - ffn1_ckpt_name=ffn1_ckpt_name, - ffn2_ckpt_name=ffn2_ckpt_name, + ffn1_ckpt_name=self.ffn1_ckpt_name, + ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) @@ -1198,15 +1181,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): quantizer_set=ffn1_quantizer_set, ) - if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None: - dot_1_output_axes = ( - *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis), - *get_non_contracting_logical_axes( - kernel_1.ndim, self.kernel_axes_1, contract_ind - ), - ) - x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes) - if self.enable_low_rank_adaptation: wi_lora_a_kernel_each_shape = ( kernel_1_each_shape[: len(axis)], @@ -1247,7 +1221,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if self.use_bias: x += jnp.reshape(bias_1, bias_1_shape) - x = checkpoint_name(x, ffn1_ckpt_name) + x = checkpoint_name(x, self.ffn1_ckpt_name) if is_act_implemented: z = activation(x, normalized_acts) else: @@ -1310,7 +1284,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if self.use_bias: out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) - out = checkpoint_name(out, ffn2_ckpt_name) + out = checkpoint_name(out, self.ffn2_ckpt_name) assert out.dtype == input_dtype return out, ln_output # Output, layner_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index f2c0bc2a1..fb3ac7b9a 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -26,6 +26,7 @@ from ..attention import AttnBiasType, AttnMaskType, QKVLayout, 
SequenceDescriptor from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type from ..attention import fused_attn +from ..attention import CPStrategy from ..softmax import SoftmaxType from ..sharding import num_of_devices from ..sharding import get_sharding_map_logic_axis_to_mesh_axis @@ -274,6 +275,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT + context_checkpoint_name: str = "context" @nn.compact def __call__( @@ -322,6 +325,8 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, + context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_kvpacked(): """kvpacked format, treat @@ -348,6 +353,8 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, + context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_separate(): if self.transpose_batch_sequence: @@ -369,6 +376,8 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, + context_checkpoint_name=self.context_checkpoint_name, ) else: raise ValueError(f"Unsupported {self.qkv_layout=}.") @@ -501,6 +510,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING. + context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention. Optimization parameters ----------------------- @@ -524,6 +535,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_parallel_strategy: str = "DEFAULT" + context_checkpoint_name: str = "context" @nn.compact def __call__( @@ -642,6 +655,24 @@ def __call__( scale_factor = self.scale_factor del self.scale_factor + # case-insensitive mapping for context parallel strategy + cp_strategy_map = { + "DEFAULT": CPStrategy.DEFAULT, + "ALL_GATHER": CPStrategy.ALL_GATHER, + "ALLGATHER": CPStrategy.ALL_GATHER, # Alternative spelling + "RING": CPStrategy.RING, + } + + strategy_key = self.context_parallel_strategy.upper() + if strategy_key in cp_strategy_map: + context_parallel_strategy = cp_strategy_map[strategy_key] + else: + valid_strategies = list(cp_strategy_map.keys()) + raise ValueError( + f"Invalid context parallel strategy: {self.context_parallel_strategy}. 
" + f"Valid options are: {valid_strategies} (case insensitive)" + ) + if not use_fused_attn: # unfused attention only supports splitted query, key, value if qkv_layout.is_qkvpacked(): @@ -690,6 +721,8 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=context_parallel_strategy, + context_checkpoint_name=self.context_checkpoint_name, )( query, key, @@ -1160,7 +1193,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): epsilon=self.layernorm_epsilon, axis=-1, features=(3, self.num_attention_heads * self.head_dim), - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=self.return_layernorm_output, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), @@ -1187,7 +1219,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): epsilon=self.layernorm_epsilon, axis=-1, features=self.num_attention_heads * self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=(self.return_layernorm_output or is_self_attn), scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), @@ -1212,7 +1243,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kv_proj = DenseGeneral( axis=-1, features=(2, self.num_gqa_groups * self.head_dim), - transpose_batch_sequence=self.transpose_batch_sequence, kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_init=kv_init, use_bias=self.use_bias, @@ -1231,7 +1261,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): DenseGeneral, axis=-1, features=self.num_gqa_groups * self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, kernel_axes=(W_FSDP_AXES, W_TP_AXES), use_bias=self.use_bias, bias_init=self.bias_init, @@ -1248,7 +1277,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): epsilon=self.layernorm_epsilon, axis=-1, features=self.num_attention_heads * self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=True, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), @@ -1413,7 +1441,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): out = DenseGeneral( features=inputs_q.shape[-1], - transpose_batch_sequence=self.transpose_batch_sequence, axis=-1, kernel_init=self.kernel_init, kernel_axes=(W_TP_AXES, W_FSDP_AXES), @@ -2015,7 +2042,6 @@ def hidden_dropout(x, deterministic): layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, epsilon=self.layernorm_epsilon, - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, @@ -2070,7 +2096,6 @@ def hidden_dropout(x, deterministic): epsilon=self.layernorm_epsilon, scale_axes=(W_NO_SHARD_AXES,), bias_axes=(W_NO_SHARD_AXES,), - transpose_batch_sequence=self.transpose_batch_sequence, dtype=self.dtype, name="output_layernorm", )(z) diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 7a3ad597b..0f5c6aeef 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -17,7 +17,6 @@ from . 
import cpp_extensions as tex from .quantize import ( - ScaledTensor, Quantizer, ) @@ -112,8 +111,8 @@ def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, eps output, mu, rsigma = tex.normalization_fwd( x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer ) - if isinstance(output, ScaledTensor): - output = output.dequantize() + # This is a no-op for higher-precision tensors + output = output.dequantize() return output, (x, mu, rsigma, gamma, beta, quantizer) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 5ccfc71c2..fb9783075 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -9,7 +9,6 @@ distributed training through sharding constraints. """ -import warnings from functools import partial from typing import Tuple @@ -23,19 +22,10 @@ noop_quantizer_set, with_sharding_constraint_by_logical_axes, TensorUsage, + get_quantize_config, ) -LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False - - -def _issue_batch_first_warning(msg): - global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED - if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED: - warnings.warn(msg, UserWarning) - LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True - - def layernorm_dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -48,7 +38,6 @@ def layernorm_dense( layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, - batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -69,7 +58,6 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix - batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_set: Set of quantizers for different tensor types Returns: @@ -81,6 +69,11 @@ def layernorm_dense( - The function supports automatic differentiation through JAX's custom VJP - Quantization is applied to both the normalized input and kernel """ + + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + output = _layernorm_dense( x, kernel, @@ -93,7 +86,6 @@ def layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, - batch_first, quantizer_set, ) return output @@ -108,7 +100,6 @@ def layernorm_dense( 8, 9, 10, - 11, ), ) def _layernorm_dense( @@ -123,7 +114,6 @@ def _layernorm_dense( layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], - batch_first: bool, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -143,7 +133,6 @@ def _layernorm_dense( epsilon: Small constant for numerical stability layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding - batch_first: Assume that X is batched in the first dimension. quantizer_set: Set of quantizers Returns: @@ -161,7 +150,6 @@ def _layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, - batch_first, quantizer_set, ) return output @@ -179,7 +167,6 @@ def _layernorm_dense_fwd_rule( layernorm_input_axes, dot_input_axes, kernel_axes, - batch_first, quantizer_set, ): """Forward pass rule for layernorm_dense. 
@@ -197,17 +184,6 @@ def _layernorm_dense_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] - x_bdim = None - if x.ndim > 2: - if not batch_first: - _issue_batch_first_warning( - "TE/JAX `layernorm_dense()` fused-layer implementation does not officially " - "support sequence-first inputs and may produce incorrect results when " - "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " - "inputs at your own discretion." - ) - x_bdim = 0 if batch_first else x.ndim - 2 - x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) casted_ln_out, mu, rsigma = tex.normalization_fwd( @@ -218,14 +194,15 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) flatten_axis = 1 - len(kernel.shape) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + kernel, + flatten_axis=flatten_axis, + quantizer=quantizer_set.kernel, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -236,7 +213,6 @@ def _layernorm_dense_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -260,7 +236,6 @@ def _layernorm_dense_fwd_rule( use_bias, quantizer_set, flatten_axis, - x_bdim, ) return output, ctx @@ -271,9 +246,8 @@ def _layernorm_dense_bwd_rule( zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes, # pylint: disable=unused-argument + dot_input_axes, kernel_axes, - batch_first, # pylint: disable=unused-argument ctx, grad, ): @@ -288,6 +262,7 @@ def _layernorm_dense_bwd_rule( Returns: Tuple of gradients for all input parameters """ + del dot_input_axes ( casted_ln_out, casted_kernel, @@ -303,7 +278,6 @@ def _layernorm_dense_bwd_rule( use_bias, quantizer_set, flatten_axis, - x_bdim, ) = ctx casted_grad, dbias = tex.quantize_dbias( @@ -311,7 +285,6 @@ def _layernorm_dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -328,7 +301,6 @@ def _layernorm_dense_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, contracting_dims=(g_constracting_dim, k_constracting_dim), - batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -342,7 +314,6 @@ def _layernorm_dense_bwd_rule( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_constracting_dim, g_constracting_dim), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 507c49c7e..fc957801a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -13,7 +13,6 @@ quantization, and distributed training through sharding constraints. 
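Because batch_first is dropped from the positional signatures of the custom-VJP wrappers in dense.py, layernorm_dense.py, and (below) layernorm_mlp.py, the nondiff_argnums tuples and the bwd-rule signatures are renumbered in lockstep. A minimal toy reminder of the pattern, not TE code: static arguments listed in nondiff_argnums are prepended to the bwd rule, so removing a positional parameter shifts every index after it.

    from functools import partial
    import jax

    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def scale(x, factor):           # factor is a static (non-differentiable) argument
        return x * factor

    def scale_fwd(x, factor):
        return x * factor, None     # no residuals needed for this toy example

    def scale_bwd(factor, residuals, grad_out):
        # Static args come first in the bwd signature; gradients are returned
        # only for the differentiable inputs.
        return (grad_out * factor,)

    scale.defvjp(scale_fwd, scale_bwd)
    assert jax.grad(scale)(2.0, 3.0) == 3.0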
""" -import warnings from typing import List, Tuple, Sequence, Union, Callable from functools import partial @@ -28,18 +27,8 @@ QuantizerSet, noop_quantizer_set, TensorUsage, + get_quantize_config, ) -from .sharding import get_non_contracting_logical_axes - - -LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False - - -def _issue_batch_first_warning(msg): - global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED - if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED: - warnings.warn(msg, UserWarning) - LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True def layernorm_mlp( @@ -59,7 +48,6 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), - batch_first: bool = True, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -91,7 +79,6 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation - batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -118,6 +105,11 @@ def layernorm_mlp( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel_1 = kernel_1.astype(input_dtype) + kernel_2 = kernel_2.astype(input_dtype) + output = _layernorm_mlp( x, gamma, @@ -137,13 +129,12 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -163,7 +154,6 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], - batch_first: bool, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -189,7 +179,6 @@ def _layernorm_mlp( ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) - batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of quantizer sets Returns: @@ -214,7 +203,6 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, quantizer_sets, ) return output @@ -239,7 +227,6 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, quantizer_sets, ): """Forward pass rule for layernorm_mlp. 
@@ -256,7 +243,7 @@ def _layernorm_mlp_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - del kernel_2_axes + del kernel_1_axes, kernel_2_axes ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -272,17 +259,6 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] - x_bdim = None - if x.ndim > 2: - if not batch_first: - _issue_batch_first_warning( - "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially " - "support sequence-first inputs and may produce incorrect results when " - "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " - "inputs at your own discretion." - ) - x_bdim = 0 if batch_first else x.ndim - 2 - use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -296,12 +272,13 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + kernel_1, + flatten_axis=-2, + quantizer=ffn1_quantizer_set.kernel, ) # NN GEMM @@ -310,34 +287,36 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) - if dot_1_input_axes is not None and kernel_1_axes is not None: - dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), - ) - dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) - if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) + # This sharding constraint is needed to correct the Shardy sharding propagation + if dot_2_input_axes is not None: + dot_1_output_axes = ( + dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:] + ) # add the act_num axis + dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) + dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) casted_act_out = tex.act_lu( - dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + dot_1_output, + activation_type, + quantizer=ffn2_quantizer_set.x, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_kernel_2 = tex.quantize( - kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + kernel_2, + quantizer=ffn2_quantizer_set.kernel, ) # NN GEMM @@ -346,7 +325,6 @@ def _layernorm_mlp_fwd_rule( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, ) @@ -376,7 +354,6 @@ def _layernorm_mlp_fwd_rule( use_bias_1, use_bias_2, quantizer_sets, - x_bdim, ) return dot_2_output, ctx @@ -394,7 +371,6 @@ def _layernorm_mlp_bwd_rule( 
ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, ctx, grad, ): @@ -411,7 +387,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, @@ -430,7 +406,6 @@ def _layernorm_mlp_bwd_rule( use_bias_1, use_bias_2, quantizer_sets, - x_bdim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -439,7 +414,9 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True + grad, + is_dbias=use_bias_2, + quantizer=ffn1_quantizer_set.dgrad, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -457,7 +434,6 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), - batched_dims=((x_bdim,), ()), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -472,7 +448,6 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -482,7 +457,6 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -500,7 +474,6 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), - batched_dims=((x_bdim,), ()), ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -511,7 +484,6 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 0b9659a46..4037eae80 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -9,9 +9,11 @@ This module provides configuration and helper functions for managing quantization metadata in JAX, including support for different scaling modes and datatypes. 
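Returning to the Shardy sharding fix in the layernorm_mlp forward rule above, the derived constraint simply splices an unsharded axis in front of the hidden axis of dot_2_input_axes, because the first GEMM output carries an extra activation-group dimension. A tiny illustration with made-up logical axis names:

    # Hypothetical logical axes for the activation input of the second GEMM.
    dot_2_input_axes = ("batch", "seq", "mlp_hidden")

    # Insert None for the act_num axis so that dimension stays unsharded.
    dot_1_output_axes = dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:]
    assert dot_1_output_axes == ("batch", "seq", None, "mlp_hidden")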
""" +from abc import ABC, abstractmethod from contextlib import contextmanager +from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple, Dict, Union, Sequence +from typing import Optional, Tuple, Dict, Union, Sequence, Type from functools import reduce import operator @@ -35,7 +37,7 @@ from .device_utils import get_device_compute_capability __all__ = [ - "QuantizeConfig", + "get_quantize_config", "fp8_autocast", "is_fp8_available", "update_collections", @@ -43,12 +45,15 @@ "apply_padding_to_scale_inv", "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", + "TensorSource", ] _is_fp8_available = None _reason_for_no_fp8 = "" Collection = Union[Dict, FrozenDict] +NVTE_FP8_COLLECTION_NAME = "fp8_metas" + def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. @@ -170,6 +175,17 @@ def _format2dtypes(format_: recipe.Format): return jnp.bfloat16, jnp.bfloat16 +class TensorSource(Enum): + """Enumeration for where a tensor's data comes from.""" + + # Input data + X = 0 + # Model parameters + KERNEL = 1 + # Gradients in the backward pass + DGRAD = 2 + + class AmaxComputeAlgo(Enum): """Enumeration for AMAX computation algorithms. @@ -182,28 +198,8 @@ class AmaxComputeAlgo(Enum): MOST_RECENT = "most_recent" -def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: - """Convert recipe.Recipe to ScalingMode. - - Args: - fp8_recipe: The FP8 recipe to convert - - Returns: - The corresponding ScalingMode - - Raises: - ValueError: If the recipe type is not supported - """ - if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.DELAYED_TENSOR_SCALING - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.MXFP8_1D_SCALING - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): - return ScalingMode.CURRENT_TENSOR_SCALING - raise ValueError("Invalid fp8_recipe!") - - -class QuantizeConfig: +@dataclass +class BaseQuantizeConfig(ABC): """Configuration class for quantization settings. This class manages global quantization settings including FP8 formats, @@ -220,14 +216,13 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients INFERENCE_MODE: Whether to enable optimization for inference - SCALING_MODE: Scaling mode AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_COMPUTE_ALGO: Algorithm for AMAX computation """ INITIALIZED = False MARGIN: float = 0.0 - COLLECTION_NAME: str = "fp8_metas" + COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME FP8_FORMAT: recipe.Format = recipe.Format.HYBRID FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] @@ -235,61 +230,82 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False INFERENCE_MODE: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX - @staticmethod - def is_fp8_enabled(): + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize the quantization configuration. 
+ + Args: + fp8_recipe: The FP8 recipe to use for initialization + """ + self.INITIALIZED = True + self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 + self.FP8_FORMAT = fp8_recipe.fp8_format + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT) + + def is_fp8_enabled(self) -> bool: """Check if FP8 quantization is enabled. Returns: bool: True if quantization is enabled, False otherwise """ - return QuantizeConfig.INITIALIZED + return self.INITIALIZED - @classmethod - def initialize(cls, fp8_recipe: recipe.Recipe) -> None: - """Initialize the quantization configuration. + @abstractmethod + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type. Args: - fp8_recipe: The FP8 recipe to use for initialization + tensor_source: The usage type for which to get the scaling mode. + + Returns: + The scaling mode for the specified usage type. + """ + + def is_supported(self) -> tuple[bool, str]: + """Check if this QuantizeConfig class is supported on the available devices. + + Returns: + bool: True if the class is supported, False otherwise + str: Reason for being unsupported, if applicable. """ - cls.INITIALIZED = True - cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 - cls.FP8_FORMAT = fp8_recipe.fp8_format - cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) - - @classmethod - def finalize(cls) -> None: - """Reset the quantization configuration to default values.""" - cls.INITIALIZED = False - cls.MARGIN = 0.0 - cls.FP8_FORMAT = recipe.Format.HYBRID - cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.FP8_2X_ACC_FPROP = False - cls.FP8_2X_ACC_DGRAD = False - cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.INFERENCE_MODE = False - # DelayedScaling - cls.AMAX_HISTORY_LEN = 1024 - cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX - - -class DelayedScalingQuantizeConfig: + + x_scaling_mode = self.get_scaling_mode(TensorSource.X) + kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD) + for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]: + is_supported, reason = is_fp8_available(scaling_mode=scaling_mode) + if not is_supported: + return is_supported, reason + return True, None + + +class NoOpQuantizeConfig(BaseQuantizeConfig): + """Configuration class higher-precision non-quantized operation.""" + + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize no-op configuration.""" + raise NotImplementedError( + "NoOpQuantizeConfig cannot be initialize from a recipe as it represents" + " higher-precision when no quantized recipe is set." + ) + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.NO_SCALING + + +class DelayedScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for delayed scaling FP8 recipe. This class provides specific initialization and finalization for delayed scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize delayed scaling FP8 configuration. 
Args: @@ -298,6 +314,8 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: Raises: AssertionError: If recipe parameters are not supported """ + super().initialize_from_recipe(fp8_recipe) + assert fp8_recipe.amax_compute_algo in [ "max", "most_recent", @@ -307,71 +325,88 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." - cls = QuantizeConfig - cls.initialize(fp8_recipe) - - cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len + self.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len string_to_amax_compute_algo = { "max": AmaxComputeAlgo.MAX, "most_recent": AmaxComputeAlgo.MOST_RECENT, } - cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] + self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] - cls.FP8_2X_ACC_DGRAD = True - cls.FP8_2X_ACC_WGRAD = True + self.FP8_2X_ACC_DGRAD = True + self.FP8_2X_ACC_WGRAD = True - @staticmethod - def finalize() -> None: - """Reset the delayed scaling configuration.""" - QuantizeConfig.finalize() + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.DELAYED_TENSOR_SCALING -class CurrentScalingQuantizeConfig: +class CurrentScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for current scaling FP8 recipe. This class provides specific initialization and finalization for current scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize current scaling FP8 configuration. Args: fp8_recipe: The FP8 recipe to use for initialization """ - cls = QuantizeConfig - cls.initialize(fp8_recipe) - cls.AMAX_HISTORY_LEN = 0 + super().initialize_from_recipe(fp8_recipe) + self.AMAX_HISTORY_LEN = 0 - @staticmethod - def finalize() -> None: - """Reset the current scaling configuration.""" - QuantizeConfig.finalize() + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.CURRENT_TENSOR_SCALING -class BlockScalingQuantizeConfig: +class BlockScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for block scaling FP8 recipe. This class provides specific initialization and finalization for block scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize block scaling FP8 configuration. 
Args: fp8_recipe: The FP8 recipe to use for initialization """ - cls = QuantizeConfig - cls.initialize(fp8_recipe) - cls.AMAX_HISTORY_LEN = 0 + super().initialize_from_recipe(fp8_recipe) + self.AMAX_HISTORY_LEN = 0 + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.MXFP8_1D_SCALING + + +_QUANTIZE_CONFIG = NoOpQuantizeConfig() - @staticmethod - def finalize() -> None: - """Reset the block scaling configuration.""" - QuantizeConfig.finalize() + +def get_quantize_config(): + """Global instance of BaseQuantizeConfig set by fp8_autocast context.""" + return _QUANTIZE_CONFIG + + +def get_quantize_config_class( + fp8_recipe: recipe.Recipe, +) -> Type[BaseQuantizeConfig]: + """Get the quantization configuration based on the FP8 recipe. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + Returns: + The quantization config class corresponding to the given recipe. + """ + if isinstance(fp8_recipe, recipe.DelayedScaling): + return DelayedScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + return BlockScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.Float8CurrentScaling): + return CurrentScalingQuantizeConfig + raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") @contextmanager @@ -420,25 +455,22 @@ def fp8_autocast( if fp8_recipe is None: fp8_recipe = recipe.DelayedScaling() - if mesh_resource is None: - mesh_resource = MeshResource() + global _QUANTIZE_CONFIG - Config = DelayedScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - Config = BlockScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): - Config = CurrentScalingQuantizeConfig + old_quantize_config = _QUANTIZE_CONFIG + + _QUANTIZE_CONFIG = NoOpQuantizeConfig() try: with global_shard_guard(mesh_resource): if enabled: - fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe)) - assert fp8_available, reason_for_no_fp8 - - Config.initialize(fp8_recipe) + _QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)() + is_supported, reason = _QUANTIZE_CONFIG.is_supported() + assert is_supported, reason + _QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe) yield finally: - Config.finalize() + _QUANTIZE_CONFIG = old_quantize_config def get_delayed_scaling(): @@ -456,12 +488,12 @@ def get_delayed_scaling(): an instance of DelayedScaling which is set via fp8_autocast. 
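# --- Illustrative sketch (not part of this change): how the recipe-to-config
# dispatch added above is expected to be used. The import path mirrors
# transformer_engine/jax/quantize/helper.py but is an assumption, not verified
# public API; everything below is example-only.
from transformer_engine.common import recipe
from transformer_engine.jax.quantize.helper import (
    TensorSource,
    get_quantize_config_class,
)

config = get_quantize_config_class(recipe.DelayedScaling())()
# For DelayedScaling every tensor source resolves to the same scaling mode;
# per the docstring above, future recipes may differ between x, kernel and grad.
for source in (TensorSource.X, TensorSource.KERNEL, TensorSource.DGRAD):
    print(source.name, config.get_scaling_mode(source))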
""" amax_compute_algo = ( - "max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" + "max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" ) return recipe.DelayedScaling( - margin=int(QuantizeConfig.MARGIN), - fp8_format=QuantizeConfig.FP8_FORMAT, - amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN, + margin=int(get_quantize_config().MARGIN), + fp8_format=get_quantize_config().FP8_FORMAT, + amax_history_len=get_quantize_config().AMAX_HISTORY_LEN, amax_compute_algo=amax_compute_algo, ) @@ -600,6 +632,3 @@ def apply_padding_to_scale_inv( # Pad the scales with the lowest representable value (2^-127) and return pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) - - -NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 881f3a74b..306603bbe 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -16,12 +16,21 @@ import jax.numpy as jnp from jax.tree_util import register_pytree_node_class from transformer_engine_jax import QuantizeLayout +from transformer_engine.common import recipe from .scaling_modes import ScalingMode -from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory +from .tensor import ( + ScaledTensor, + ScaledTensor1x, + ScaledTensor2x, + ScaledTensorFactory, + NoScaleTensor, +) from .helper import ( - QuantizeConfig, + get_quantize_config, + get_quantize_config_class, AmaxComputeAlgo, + TensorSource, ) from .device_utils import is_fp8_gemm_with_all_layouts_supported @@ -54,7 +63,7 @@ def compute_scale_from_amax( fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) if scale is None: scale = jnp.ones((1,)) - sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) + sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) return sf @@ -214,7 +223,11 @@ class CurrentScaleQuantizer(Quantizer): data_layout: str = "NT" def _quantize_func( - self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 + self, + x: Union[jnp.ndarray, NoScaleTensor], + is_colwise=False, + dq_dtype=None, + flatten_axis=-1, ) -> ScaledTensor1x: """Quantize function helper for delayed scaling FP8. 
@@ -226,14 +239,17 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - amax = jnp.max(jnp.abs(x)).reshape((1,)) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) - scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) - scaled_x = x.astype(compute_dtype) * scale + scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) + scaled_x = x.data.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / scale @@ -260,7 +276,10 @@ def quantize( Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype if flatten_axis < 0: flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" @@ -318,7 +337,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) amax_history: jnp.ndarray = field( - default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32) + default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32) ) def tree_flatten(self): @@ -344,11 +363,14 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - scaled_x = x.astype(compute_dtype) * self.scale + scaled_x = x.data.astype(compute_dtype) * self.scale # quantize() in the old dot.py do this way, leave this code block here for future debugging # compute_dtype = x.dtype @@ -357,7 +379,8 @@ def _quantize_func( clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / self.scale - self.update(jnp.max(jnp.abs(x)).reshape((1,))) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) + self.update(amax) return ScaledTensorFactory.create_1x( data=clipped_scaled_x, scale_inv=scale_inv, @@ -395,7 +418,7 @@ def _compute_scale(amax_history, scale, q_dtype): Updated scale value """ # 2. Calculate the current scale - if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: + if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: amax = jnp.max(amax_history, axis=-1, keepdims=True) else: amax = amax_history[0:1] @@ -457,6 +480,10 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> Returns: A ScaledTensor1x containing the quantized data """ + if isinstance(x, NoScaleTensor): + # No need for amax in MXFP8 block scaling, so simply extract the jnp.ndarray data tensor from the NoScaleTensor x. 
+ x = x.data + # TODO(Phuong): use quantize_func from JAX if flatten_axis < 0: flatten_axis = x.ndim + flatten_axis @@ -492,7 +519,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> return ScaledTensorFactory.create_1x( x_q, scales_q, - self.scaling_mode, + scaling_mode=self.scaling_mode, is_colwise=is_colwise, dq_dtype=dq_dtype, flatten_axis=flatten_axis, @@ -638,11 +665,11 @@ def _create_grouped_tensor_from_tensor_list( return ScaledTensorFactory.create_1x( grouped_data, grouped_scale_inv, - self.scaling_mode, - tensor_list[0].dq_dtype, - tensor_list[0].is_colwise, - tensor_list[0].data_layout, - tensor_list[0].flatten_axis, + scaling_mode=self.scaling_mode, + dq_dtype=tensor_list[0].dq_dtype, + is_colwise=tensor_list[0].is_colwise, + data_layout=tensor_list[0].data_layout, + flatten_axis=tensor_list[0].flatten_axis, group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, @@ -825,12 +852,21 @@ def create( @staticmethod def _create_set( - scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + x_scaling_mode, + kernel_scaling_mode, + grad_scaling_mode, + fwd_dtype, + bwd_dtype, + is_2x2x, + n_groups, + **kwargs, ) -> QuantizerSet: """Create a set of quantizers for forward and backward passes. Args: - scaling_mode: Scaling mode to use + x_scaling_mode: Scaling mode to use for input tensor 'x' + kernel_scaling_mode: Scaling mode to use for kernel tensor + grad_scaling_mode: Scaling mode to use for gradient tensor fwd_dtype: Data type for forward pass bwd_dtype: Data type for backward pass is_2x2x: Whether to use 2x2x quantization @@ -844,9 +880,9 @@ def _create_set( q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE else: q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE - if scaling_mode.is_1d_block_scaling(): + if kernel_scaling_mode.is_1d_block_scaling(): q_layout_kernel = QuantizeLayout.COLWISE - if QuantizeConfig.INFERENCE_MODE: + if get_quantize_config().INFERENCE_MODE: q_layout_dgrad = None if "quantize_meta_set" in kwargs: @@ -866,57 +902,88 @@ def _create_set( else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) + q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) q_kernel = QuantizerFactory.create( - 1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel + 1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel ) q_dgrad = QuantizerFactory.create( - 1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad + 1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad ) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @staticmethod def create_set( n_quantizer_sets: int = 1, - scaling_mode: ScalingMode = None, + scaling_mode: Optional[ScalingMode] = None, fwd_dtype: jnp.dtype = None, bwd_dtype: jnp.dtype = None, is_2x2x: bool = None, n_groups: int = None, + fp8_recipe: Optional[recipe.Recipe] = None, **kwargs, ) -> tuple[Union[tuple[Quantizer], None]]: """Create one or more sets of quantizers. 
Args: n_quantizer_sets: Number of quantizer sets to create - scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE - fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE - bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE - is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X + scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode + fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE + bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE + is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X n_groups: + fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. **kwargs: Additional arguments for quantizer initialization Returns: A single quantizer set or tuple of quantizer sets """ - scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE - fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE - bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE + + assert scaling_mode is None or fp8_recipe is None, ( + "Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling" + " mode can be specified directly via the scaling_mode parameter or indirectly via" + " recipe. Recipe is preferred as it will support additional recipes in future where" + " scaling mode differs between x, kernel, and grad in the quantizer set." + ) + + if fp8_recipe is not None: + quantize_config = get_quantize_config_class(fp8_recipe)() + x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) + kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) + elif scaling_mode is not None: + x_scaling_mode = scaling_mode + kernel_scaling_mode = scaling_mode + grad_scaling_mode = scaling_mode + else: + x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) + kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) + + fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE + bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE if is_2x2x is None: - if scaling_mode.is_1d_block_scaling(): + # TODO(Jeremy): check x, kernel, grad separately for 2x + if x_scaling_mode.is_1d_block_scaling(): is_2x2x = True - elif scaling_mode.is_tensor_scaling(): + elif x_scaling_mode.is_tensor_scaling(): is_2x2x = not is_fp8_gemm_with_all_layouts_supported() else: # NO_SCALING ignores is_2x2x for now is_2x2x = False - is_inference_mode = QuantizeConfig.INFERENCE_MODE + is_inference_mode = get_quantize_config().INFERENCE_MODE assert not is_inference_mode, "Inference mode is not supported yet!" 
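# --- Sketch of the recipe-driven quantizer-set creation described above. The
# QuantizerFactory import path mirrors transformer_engine/jax/quantize/quantizer.py
# and is an assumption; the dtypes and recipe choice are example-only.
import jax.numpy as jnp
from transformer_engine.common import recipe
from transformer_engine.jax.quantize.quantizer import QuantizerFactory

# Preferred path: pass a recipe and let the config choose per-tensor scaling modes.
# Passing both scaling_mode and fp8_recipe trips the assertion above.
quantizers = QuantizerFactory.create_set(
    n_quantizer_sets=1,
    fp8_recipe=recipe.Float8CurrentScaling(),
    fwd_dtype=jnp.float8_e4m3fn,
    bwd_dtype=jnp.float8_e5m2,
)
# quantizers.x, quantizers.kernel and quantizers.dgrad hold one quantizer each.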
q_set = [] for _ in range(n_quantizer_sets): q_set.append( QuantizerFactory._create_set( - scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + x_scaling_mode=x_scaling_mode, + kernel_scaling_mode=kernel_scaling_mode, + grad_scaling_mode=grad_scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=is_2x2x, + n_groups=n_groups, + **kwargs, ) ) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index fc4fd1353..e81a614f0 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -166,6 +166,90 @@ def get_shardy_sharding_rules( """ +class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): + """Implementation for no scaling mode. + + This implementation provides metadata for no scaling mode, for using non-quantized higher-precision datatypes such as bf16. + """ + + def get_scale_dtype(self) -> jnp.dtype: + """Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling. + + Returns: + The data type used for scale tensors (float32) + """ + return jnp.float32 + + def get_scale_shape( + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, + ) -> Tuple[int, ...]: + """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. + + Args: + data_shape: The shape of the tensor being scaled + is_colwise: Whether the scaling is column-wise + is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors - (1,) + """ + del data_shape, is_colwise, is_padded, flatten_axis + return (0,) + + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + return QuantizeLayout.ROWWISE + + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: + """Get the shape for scale tensors in this mode. + + Args: + data_shape: Original shape of the data tensor + is_colwise: Whether to use column-wise scaling + is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors + """ + del data_shape, group_axis, is_colwise + assert isinstance(n_groups, int) + return (n_groups,) + + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + del flatten_axis + input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + scale_var = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + + class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for current scaling mode. 
@@ -396,7 +480,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: The quantize layout for the tensor usage """ # If we need to support 1x1x for inference in the future - # if QuantizeConfig.INFERENCE_MODE: + # if get_quantize_config().INFERENCE_MODE: # assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!") # if usage == TensorUsage.LHS: # return QuantizeLayout.ROWWISE @@ -740,5 +824,5 @@ def tree_unflatten(cls, aux_data, _children): ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), - ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 97e127269..dbbac4abc 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -25,6 +25,8 @@ __all__ = [ "TensorUsage", + "AbstractBaseTensor", + "NoScaleTensor", "ScaledTensor", "ScaledTensor1x", "ScaledTensor2x", @@ -34,14 +36,9 @@ ] -@register_pytree_node_class @dataclass -class ScaledTensor(ABC): - """Abstract base class for scaled tensors. - - This class defines the interface for all scaled tensor implementations, - providing methods for dequantization and accessing row/column-wise components. - """ +class AbstractBaseTensor(ABC): + """Abstract base class for all tensor types.""" @classmethod def tree_unflatten(cls, aux_data, children): @@ -93,9 +90,76 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st """ +@dataclass +class AbstractBaseTensor1x(AbstractBaseTensor): + """Abstract base class for single layout tensors.""" + + data: jnp.ndarray + amax: jnp.ndarray + + +@register_pytree_node_class +@dataclass +class NoScaleTensor(AbstractBaseTensor1x): + """Higher-precision tensor.""" + + def __post_init__(self): + assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray." + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations. + + Returns: + A tuple containing (children, aux_data) for tree operations + """ + children = (self.data, self.amax) + aux_data = () + return (children, aux_data) + + @property + def ndim(self): + """Number of dimensions of the underlying array.""" + return self.data.ndim + + def dequantize(self): + """This is a no-op for a higher-precision tensor so this simply returns the tensor's data.""" + return self.data + + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) + assert ( + q_layout == QuantizeLayout.ROWWISE + ), "Only ROWWISE layout is supported for NoScaleTensor" + return self + + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. 
+ + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names) + + return NoScaleTensor( + data=data, + amax=self.amax, + ) + + +class ScaledTensor(ABC): + """Abstract base class for scaled tensors.""" + + @register_pytree_node_class @dataclass -class ScaledTensor1x(ScaledTensor): +class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): """Single-scale quantized tensor implementation. This class represents a tensor quantized with a single scaling factor, @@ -104,6 +168,7 @@ class ScaledTensor1x(ScaledTensor): Attributes: data: The quantized tensor data scale_inv: The inverse scaling factors + amax: The maximum absolute value of the tensor scaling_mode: The scaling mode used for quantization dq_dtype: The data type for dequantized values _dq_func: The dequantization function @@ -112,7 +177,6 @@ class ScaledTensor1x(ScaledTensor): flatten_axis: The quantization axis for the tensor """ - data: jnp.ndarray scale_inv: jnp.ndarray scaling_mode: ScalingMode dq_dtype: jnp.dtype @@ -152,7 +216,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv) + children = (self.data, self.amax, self.scale_inv) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -224,6 +288,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st return ScaledTensor1x( data=data, scale_inv=scale_inv, + amax=self.amax, scaling_mode=self.scaling_mode, dq_dtype=self.dq_dtype, _dq_func=self._dq_func, @@ -255,6 +320,7 @@ def __init__( self, data, scale_inv, + amax, group_sizes, scaling_mode, dq_dtype, @@ -270,7 +336,15 @@ def __init__( self.original_shape = original_shape self.group_axis = group_axis super().__init__( - data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=_dq_func, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) def __post_init__(self): @@ -308,7 +382,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.group_sizes) + children = (self.data, self.scale_inv, self.amax, self.group_sizes) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -327,7 +401,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st @register_pytree_node_class @dataclass -class ScaledTensor2x(ScaledTensor): +class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): """Double-scale quantized tensor implementation. This class represents a tensor quantized with both row-wise and column-wise scaling factors. 
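# --- Sketch of the NoScaleTensor wrapper added above: it carries unquantized
# data plus an optional pre-computed amax, and dequantize() is a no-op. The
# import path mirrors transformer_engine/jax/quantize/tensor.py and is an
# assumption, not verified public API.
import jax.numpy as jnp
from transformer_engine.jax.quantize.tensor import NoScaleTensor

x = jnp.ones((8, 16), dtype=jnp.bfloat16)
wrapped = NoScaleTensor(data=x, amax=jnp.max(jnp.abs(x)).astype(jnp.float32).reshape((1,)))
assert wrapped.dequantize() is x  # higher-precision data passes straight through
assert wrapped.ndim == x.ndim
# Quantizers in this change reuse wrapped.amax (when present) instead of
# recomputing jnp.max(jnp.abs(x)) during quantization.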
@@ -413,7 +487,8 @@ class ScaledTensorFactory: def create_1x( data, scale_inv, - scaling_mode, + amax=None, + scaling_mode=ScalingMode.NO_SCALING, dq_dtype=jnp.bfloat16, is_colwise=False, data_layout="N", @@ -427,18 +502,22 @@ def create_1x( Args: data: The quantized tensor data scale_inv: The inverse scaling factors + amax: The maximum absolute value of the tensor scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) is_colwise: Whether to use column-wise quantization (default: False) data_layout: The data_layout specification (default: "N") flatten_axis: The quantization axis for the tensor - group_sizes: Arra of ints containing the size of each group (default: None) + group_sizes: Array of ints containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) Returns: A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided """ + if amax is None: + amax = jnp.empty((1,), dtype=jnp.float32) + dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) if group_sizes is not None: @@ -468,6 +547,7 @@ def create_1x( return GroupedScaledTensor1x( data=data, scale_inv=scale_inv, + amax=amax, scaling_mode=scaling_mode, dq_dtype=dq_dtype, _dq_func=dequantizer.grouped_dequantize, @@ -485,14 +565,15 @@ def create_1x( flatten_axis = data.ndim - flatten_axis return ScaledTensor1x( - data, - scale_inv, - scaling_mode, - dq_dtype, - dequantizer.dequantize, - is_colwise, - data_layout, - flatten_axis, + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=dequantizer.dequantize, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) @staticmethod @@ -501,7 +582,8 @@ def create_2x( scale_inv, colwise_data, colwise_scale_inv, - scaling_mode, + amax=None, + scaling_mode=ScalingMode.NO_SCALING, dq_dtype=jnp.bfloat16, data_layout="NN", flatten_axis=-1, @@ -516,6 +598,7 @@ def create_2x( scale_inv: The row-wise inverse scaling factors colwise_data: The column-wise quantized data colwise_scale_inv: The column-wise inverse scaling factors + amax: The maximum absolute value of the tensor scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") @@ -527,10 +610,14 @@ def create_2x( Returns: A ScaledTensor2x instance """ + if amax is None: + amax = jnp.empty((1,), dtype=jnp.float32) + assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" rowwise_tensor = ScaledTensorFactory.create_1x( data, scale_inv, + amax, scaling_mode, dq_dtype, is_colwise=False, @@ -543,6 +630,7 @@ def create_2x( colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, + amax, scaling_mode, dq_dtype, is_colwise=True, @@ -560,7 +648,8 @@ def create( scale_inv: jnp.ndarray, colwise_data: jnp.ndarray, colwise_scale_inv: jnp.ndarray, - scaling_mode: ScalingMode, + amax=None, + scaling_mode: ScalingMode = ScalingMode.NO_SCALING, dq_dtype: jnp.dtype = jnp.bfloat16, data_layout: str = "NN", q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, @@ -594,6 +683,7 @@ def create( scale_inv, colwise_data, colwise_scale_inv, + amax, scaling_mode, dq_dtype, data_layout=data_layout, @@ -608,6 +698,7 @@ def create( return ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, + amax, 
scaling_mode, dq_dtype, is_colwise=is_colwise, @@ -621,6 +712,7 @@ def create( return ScaledTensorFactory.create_1x( data, scale_inv, + amax, scaling_mode, dq_dtype, is_colwise=is_colwise, @@ -645,7 +737,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . if isinstance(x, GroupedScaledTensor1x): raise NotImplementedError - if isinstance(x, ScaledTensor): + if isinstance(x, AbstractBaseTensor): return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index e59c9de12..339e74e2f 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -9,16 +9,14 @@ parallelism (FSDP). It includes functions for sharding constraints, mesh management, and collective operations. """ -import os from contextlib import contextmanager from dataclasses import dataclass -from enum import Enum from typing import Callable, Optional import warnings -from jax.interpreters import pxla import jax import jax.numpy as jnp -from jax.sharding import PartitionSpec +from jax.interpreters import pxla +from jax.sharding import PartitionSpec, get_abstract_mesh import numpy as np _PXLA_THREAD_RESOURCES = pxla.thread_resources @@ -43,67 +41,84 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): return mesh.shape[resource], resource +def _validate_mesh_resource_configuration(mesh_resource): + """Validate that the mesh resource configuration is consistent and conflict-free.""" + is_dp_enabled = ( + mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1 + ) + is_tp_enabled = ( + mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1 + ) + is_tpsp_enabled = ( + mesh_resource.tpsp_resource is not None + and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1 + ) + is_fsdp_enabled = ( + mesh_resource.fsdp_resource is not None + and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1 + ) + + assert not (is_dp_enabled and is_fsdp_enabled), ( + "Data parallelism and full-sharded data parallelism cannot be enabled at the same time." + f" Got dp_resource={mesh_resource.dp_resource} and" + f" fsdp_resource={mesh_resource.fsdp_resource}" + ) + assert not (is_tp_enabled and is_tpsp_enabled), ( + "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time." + f" Got tp_resource={mesh_resource.tp_resource} and" + f" tpsp_resource={mesh_resource.tpsp_resource}" + ) + + def get_sharding_map_logic_axis_to_mesh_axis(): """ Generate a dict to map logical axes to mesh axes. 
""" gsr = global_mesh_resource() - IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False))) - - batch_resources = ( - [gsr.fsdp_resource, gsr.dp_resource] - if IS_FSDP_OUTER - else [gsr.dp_resource, gsr.fsdp_resource] - ) - - batch_dim_rule = [] - for resource in batch_resources: - if resource is not None and resource not in batch_dim_rule: - batch_dim_rule.append(resource) - - if len(batch_dim_rule) <= 0: - batch_dim_rule = None - elif len(batch_dim_rule) == 1: - batch_dim_rule = batch_dim_rule[0] - else: - batch_dim_rule = tuple(batch_dim_rule) + is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 + is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1 te_logical_axis_to_mesh_axis = { - BATCH_AXES: batch_dim_rule, + BATCH_AXES: gsr.fsdp_resource if is_fsdp_enabled else gsr.dp_resource, SEQLEN_AXES: None, - SEQLEN_TP_AXES: gsr.tp_resource, + SEQLEN_TP_AXES: gsr.tpsp_resource, SEQLEN_CP_AXES: gsr.cp_resource, - HEAD_AXES: gsr.tp_resource, + HEAD_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource, HIDDEN_AXES: None, - HIDDEN_TP_AXES: gsr.tp_resource, + HIDDEN_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource, JOINED_AXES: None, W_NO_SHARD_AXES: None, W_FSDP_AXES: gsr.fsdp_resource, - W_TP_AXES: gsr.tp_resource, + W_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource, W_JOINED_AXES: None, } return te_logical_axis_to_mesh_axis -def generate_pspec(logical_axis_names): +def _generate_pspec(logical_axis_names): """ - Convert logical axes to PartitionSpec + Convert TransformerEngine logical axes (e.g. BATCH_AXES) to a JAX PartitionSpec. + Note, this method does not support Flax logical axes. + + Args: + logical_axis_names: TransformerEngine logical axes to convert to a JAX PartitionSpec. + Returns: + A JAX PartitionSpec with the mesh axes corresponding to the given TransformerEngine logical axis names """ rules = get_sharding_map_logic_axis_to_mesh_axis() - # mesh_axis_names = [rules[name] for name in logical_axis_names] - mesh_axis_names = [] - for name in logical_axis_names: - axis_name = rules[name] if name in rules else None - mesh_axis_names.append(axis_name) + + mesh_axis_names = [rules.get(name) for name in logical_axis_names] pspec = jax.sharding.PartitionSpec(*mesh_axis_names) return pspec def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): """ - A wrapper function to jax.lax.with_sharding_constraint to - support the case that Mesh is empty. + A wrapper function to jax.lax.with_sharding_constraint + 1. Does nothing if mesh is empty. + 2. If all mesh axes are manual axes, replaces pspec with all Nones. + 3. Otherwise, strips only the manual axes. 
""" if pspec is None: return x @@ -111,7 +126,14 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh if mesh.empty: return x - return jax.lax.with_sharding_constraint(x, pspec) + + # We want to exclude the axes that already used by shard_map and shard_map + # only sets those in the abstract_mesh, not the physical one + manual_axis_names = get_abstract_mesh().manual_axes + cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec) + + cleaned_pspec = PartitionSpec(*cleaned_axis_names) + return jax.lax.with_sharding_constraint(x, cleaned_pspec) def with_sharding_constraint_by_logical_axes( @@ -143,7 +165,7 @@ def with_sharding_constraint_by_logical_axes( flax_rules = flax.linen.get_logical_axis_rules() if len(flax_rules) > 0: return flax.linen.with_logical_constraint( - x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT + x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.AXIS_IS_UNSHARDED ) except ImportError: pass @@ -159,7 +181,7 @@ def with_sharding_constraint_by_logical_axes( # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table assert len(x.shape) == len(logical_axis_names) - pspec = generate_pspec(logical_axis_names) + pspec = _generate_pspec(logical_axis_names) return with_sharding_constraint(x, pspec) @@ -262,6 +284,7 @@ class MeshResource: Attributes: dp_resource: Axis name for data parallelism (batch sharding), default is None tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None + tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding), default is None fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None @@ -269,12 +292,13 @@ class MeshResource: dp_resource: str = None tp_resource: str = None + tpsp_resource: str = None fsdp_resource: str = None pp_resource: str = None cp_resource: str = None -_GLOBAL_MESH_RESOURCE = MeshResource() +_GLOBAL_MESH_RESOURCE = None @contextmanager @@ -302,6 +326,12 @@ def global_mesh_resource() -> MeshResource: Returns: The current MeshResource instance """ + assert _GLOBAL_MESH_RESOURCE is not None, ( + "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard" + " context. If you are not using multiple GPUs, you can use an empty MeshResource by" + " wrapping your program in 'with global_shard_guard(MeshResource()):'" + ) + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE @@ -334,73 +364,3 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x - - -# Deprecating Items --------------------------------------------------------------- -ShardingResource = MeshResource - -global_shard_resource = global_mesh_resource - - -class MajorShardingType(Enum): - """Enumeration of major sharding types for distributed training. - - This enum defines the basic sharding patterns available for distributed - training. Note that this class is deprecated and will be removed in the future. 
- - Values: - SINGLE: Single process training - DP: Data parallel training - TP: Standard tensor parallel training - DPTP: Data and standard tensor parallel training - """ - - SINGLE = 0 - DP = 1 - TP = 2 - DPTP = 3 - - -class ShardingType(Enum): - """Enumeration of detailed sharding types for distributed training. - - This enum defines specific sharding patterns for distributed training, - including combinations of data parallelism and different tensor parallelism - strategies. Note that this class is deprecated and will be removed in the future. - - Values: - SINGLE: No sharding - DP: Sharding along data parallelism - TP_COL: Sharding along column-split tensor parallelism - TP_ROW: Sharding along row-split tensor parallelism - DP_TP_COL: Sharding along data and column-split tensor parallelism - DP_TP_ROW: Sharding along data and row-split tensor parallelism - """ - - SINGLE = (MajorShardingType.SINGLE, "single") - DP = (MajorShardingType.DP, "dp") - TP_COL = (MajorShardingType.TP, "tp_col") - TP_ROW = (MajorShardingType.TP, "tp_row") - DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col") - DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row") - - -def get_non_contracting_logical_axes( - ndim, logical_axes: tuple[Optional[str]], contracting_dims -) -> tuple[Optional[str]]: - """Get logical axes for non-contracting dimensions. - - Args: - ndim: Number of dimensions in the tensor. - logical_axes: Tuple of logical axes for each dimension. - contracting_dims: Set of dimensions that are being contracted. - - Returns: - Tuple of logical axes for non-contracting dimensions. - """ - assert logical_axes is not None, "Logical axes must be a tuple and cannot be None." - assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions." - - non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims] - non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims) - return non_contracting_logical_axes diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 2e86a77a5..3bdbe4089 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -33,6 +33,7 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub +from transformer_engine.pytorch.module import UserBufferQuantizationMode from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.attention import InferenceParams diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 9fc2342e7..7eedd688f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -6,7 +6,7 @@ """Context Parallelism.""" import os -from typing import List, Union +from typing import List, Union, Tuple import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex @@ -361,7 +361,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, 
rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] return [ *[None] @@ -369,7 +369,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] if qkv_format == "thd": return [ @@ -832,6 +832,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -841,19 +854,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=True, **fa_forward_kwargs, @@ -988,6 +992,22 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1004,19 +1024,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1147,6 +1158,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1163,19 +1187,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1272,6 +1287,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1281,19 +1309,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1868,7 +1887,27 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -1878,16 +1917,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -1898,12 +1929,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = 0 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2019,7 +2049,29 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = 
torch.empty_like(kv_) + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2029,16 +2081,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv // 2, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2049,12 +2093,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2163,7 +2206,27 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2173,16 +2236,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q // 2, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2193,12 +2248,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse_, *fa_backward_args_thd, @@ -2270,7 +2324,15 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q) - dkv_ = torch.empty_like(kv) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + dkv_ = torch.empty_like(kv) + dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] + dv_ = dkv_[..., 1, :, :] 
if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2280,8 +2342,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) @@ -2290,12 +2352,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout, q, - kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], - kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + k_part, + v_part, out, softmax_lse, *fa_backward_args_thd, @@ -3930,3 +3991,212 @@ def attn_forward_func_with_cp( raise ValueError(f"Unsupported communication type: {cp_comm_type}!") return out + + +def pad_thd_sequences_for_cp( + input_ids: torch.Tensor, + labels: torch.Tensor, + cu_seqlens: torch.Tensor, + divisibility_factor: int, + padding_token_id: int = 0, + padding_label_id: int = -100, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pads sequences to be divisible by the divisibility factor. + + Args: + input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences + labels: Tensor of shape (1, N) or (N,) containing labels for each token + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + padding_token_id: Token ID to use for padding (default: 0) + padding_label_id: Label ID to use for padding (default: -100) + + Returns: + Tuple of: + - input_ids_padded: Padded input_ids tensor + - labels_padded: Padded labels tensor + - cu_seqlens_padded: Cumulative sequence lengths accounting for padding + """ + # Flatten input_ids and labels if needed + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + if labels.dim() == 2: + labels = labels.squeeze(0) + + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor) + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + + # Extract sequences and labels for each batch item + batch_sequences = [ + input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + batch_labels = [ + labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + + # Pad sequences and labels to required length + input_ids_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_sequences, padding_amounts) + ] + ) + labels_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_labels, padding_amounts) + ] + ) + + # Compute cumulative padded sequence lengths, starting from 0 + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + cu_seqlens_padded = torch.cumsum( + torch.cat([torch.tensor([0], 
dtype=cu_seqlens.dtype), padded_lengths]), dim=0 + ) + + return input_ids_padded, labels_padded, cu_seqlens_padded + + +def generate_positional_ids_for_cp( + cu_seqlens: torch.Tensor, + divisibility_factor: int, + dtype: torch.dtype = torch.long, +) -> torch.Tensor: + """Generate positional IDs for sequences padded to be divisible by divisibility_factor. + + Args: + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + dtype: Data type for the generated positional IDs (default: torch.long) + + Returns: + Generated positional_ids tensor where each sequence starts from 0 and continues through padding + """ + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + + # Generate positional IDs for each padded sequence (each starts from 0) + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + positional_ids = torch.cat( + [torch.arange(0, int(length), dtype=dtype) for length in padded_lengths] + ) + + return positional_ids + + +def get_batch_on_this_cp_rank( + cu_seqlens_padded: torch.Tensor, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + position_ids_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup = None, + qvk_format: str = "thd", +): + """Slice batch input along sequence dimension into multiple chunks for THD format. + + This function is inteded for use in self attention. It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + if qvk_format == "thd": + # Get context parallel size and rank + cp_size = torch.distributed.get_world_size(group=cp_group) + if cp_size > 1: + cp_rank = torch.distributed.get_rank(group=cp_group) + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." 
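# --- Sketch of the new THD context-parallel padding helpers defined above. The
# import path mirrors this file's location and is an assumption; the toy sizes
# (two packed sequences of lengths 5 and 7, divisibility factor 4, i.e.
# 2 * cp_size for cp_size = 2) are example-only.
import torch
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
    generate_positional_ids_for_cp,
    pad_thd_sequences_for_cp,
)

cu_seqlens = torch.tensor([0, 5, 12])
input_ids = torch.arange(12)
labels = torch.arange(12)

ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
    input_ids, labels, cu_seqlens, divisibility_factor=4
)
position_ids = generate_positional_ids_for_cp(cu_seqlens, divisibility_factor=4)
print(cu_seqlens_padded)                      # -> [0, 8, 16]
print(ids_padded.shape, position_ids.shape)   # -> torch.Size([16]) torch.Size([16])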
+ ) + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + # Process each tensor directly + input_ids_padded = process_tensor(input_ids_padded) + labels_padded = process_tensor(labels_padded) + position_ids_padded = process_tensor(position_ids_padded) + else: + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded, position_ids_padded diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 893e2d228..b35b87a83 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -630,7 +630,7 @@ def forward( If true, there are padding tokens between individual sequences in a packed batch. """ - with self.prepare_forward( + with torch.cuda.device(query_layer.device), self.prepare_forward( query_layer, num_gemms=3, allow_non_contiguous=True, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index e03543d40..1677689c1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -129,10 +129,10 @@ class FlashAttentionUtils: # Please follow these instructions to install FA3 v3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git -(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install +(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 -(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py""" +(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False @staticmethod @@ -439,8 +439,10 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12") + # Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version + # until the cuDNN bug is resolved + if device_compute_capability == (8, 9): + logger.debug("Disabling FusedAttention for KV caching for sm89") use_fused_attention = False if 
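For reference, a minimal usage sketch of the THD context-parallel helpers added above (pad_thd_sequences_for_cp, generate_positional_ids_for_cp, get_batch_on_this_cp_rank). The already-initialized process group `cp_group` and the direct availability of the helpers are assumptions; sequence lengths are padded to a multiple of 2 * cp_size so that every rank owns two slices of each sequence.

    # Sketch only: pack two sequences (lengths 3 and 4), pad them for a 2-way CP group,
    # then keep the slices owned by this rank. Assumes torch.distributed is initialized
    # and `cp_group` is the context-parallel process group.
    import torch

    input_ids = torch.randint(0, 32000, (1, 7))
    labels = input_ids.clone()
    cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)
    cp_size = 2
    factor = 2 * cp_size  # each padded sequence is split into 2 * cp_size slices

    ids_p, labels_p, cu_p = pad_thd_sequences_for_cp(
        input_ids, labels, cu_seqlens, divisibility_factor=factor
    )
    pos_p = generate_positional_ids_for_cp(cu_seqlens, divisibility_factor=factor)

    # Each rank keeps one slice from the front and one from the back of every sequence.
    ids_rank, labels_rank, pos_rank = get_batch_on_this_cp_rank(
        cu_p, ids_p, labels_p, pos_p, cp_group=cp_group, qvk_format="thd"
    )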
context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") @@ -482,11 +484,10 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( - use_flash_attention_3 and FlashAttentionUtils.v3_is_installed - ): - logger.debug("Disabling FlashAttention as it does not support MLA.") - use_flash_attention = False + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + use_flash_attention_2 = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": logger.debug( @@ -514,10 +515,41 @@ def get_attention_backend( ".".join([str(i) for i in device_compute_capability]), ) use_flash_attention_2 = False - if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): - if FlashAttentionUtils.v3_is_installed: - logger.debug("Disabling FlashAttention 3 for head_dim > 128") - use_flash_attention_3 = False + if use_flash_attention_3: + + def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if head_dim_qk > 256 or num_heads % num_gqa_groups != 0: + return False + if head_dim_qk != head_dim_v: + cond1 = 128 < head_dim_qk <= 192 + cond2 = 96 < head_dim_v <= 128 + cond3 = head_dim_qk <= 64 and head_dim_v <= 512 + if not ((cond1 and cond2) or cond3): + return False + if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16): + return False + return True + + if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if FlashAttentionUtils.v3_is_installed: + logger.debug( + "Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, " + "head_dim_qk, head_dim_v or qkv_dtype. " + "Supported: head_dim_qk <= 256, and num_heads %% num_gqa_groups = 0, and " + "if head_dim_qk is different from head_dim_v, then " + "(head_dim_qk must in (128, 192] and head_dim_v in (96, 128]) or " + "(head_dim_qk <= 64 and head_dim_v <= 512), and " + "if head_dim_qk is different from head_dim_v and head_dim_v > 256, then " + "qkv_dtype requires fp16 and bf16 data type. 
" + "Found: num_heads = %s, num_gqa_groups = %s, " + "head_dim_qk = %s, head_dim_v = %s and qkv_dtype = %s.", + num_heads, + num_gqa_groups, + head_dim_qk, + head_dim_v, + qkv_dtype, + ) + use_flash_attention_3 = False # Filter: QKV layout if qkv_format == "thd": @@ -615,7 +647,7 @@ def get_attention_backend( " bias for THD format" ) use_fused_attention = False - elif fp8 and head_dim_qk != head_dim_v: + elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" " MLA attention" @@ -830,7 +862,7 @@ def get_attention_backend( # flash-attn >=2.4.1 | yes # FusedAttention | # sub-backend 0 | yes - # sub-backend 1 | workspace optimization path and sm90+: yes; + # sub-backend 1 | workspace optimization path and sm90: yes; # | otherwise: no # sub-backend 2 | no # UnfusedDotProductAttention | yes @@ -846,8 +878,9 @@ def get_attention_backend( use_flash_attention_2 = False if use_fused_attention and deterministic and (not IS_HIP_EXTENSION): if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with FP8") use_fused_attention = False + fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and is_training @@ -857,8 +890,13 @@ def get_attention_backend( or cudnn_version < (8, 9, 5) ) ): - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False + fused_attention_backend = None + if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0): + logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") + use_fused_attention = False + fused_attention_backend = None # TODO: remove the filtering after ck team tells us how to enable more deterministic bwd kernels if use_fused_attention and deterministic and IS_HIP_EXTENSION: if ( @@ -867,6 +905,7 @@ def get_attention_backend( ): logger.debug("Disabling FusedAttention for determinism reasons") use_fused_attention = False + fused_attention_backend = None #TODO: switch to AOTriton when supported # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 use_flash_attention_3 = use_flash_attention and use_flash_attention_3 diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 8d5417a45..f0ef8d0bd 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -215,6 +215,17 @@ def __init__( device=torch.cuda.current_device(), ) + # This internal buffer holds the running length of each + # unfinished sequence in the batch and is updated in `pre_step()` + # method. One use of this buffer is applying RoPE to q and k tensors + # during inference by slicing ROPE Embeddings according to the + # current sequence length window. 
+ self.pre_step_seqlens = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() @@ -266,6 +277,15 @@ def pre_step( for k, v in self.sequences.items(): self.sequences_pre_step[k] = v - step_dict[k] + pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) + + # Copy the pre-step seqlens to the device in CUDA Graphs safe manner. + self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_( + pre_step_seqlens_temp, non_blocking=False + ) + seqlens_q = list(step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) @@ -280,9 +300,7 @@ def pre_step( def get_seqlens_pre_step(self): """Get cached sequence lengths before the stepping""" - return torch.Tensor(list(self.sequences_pre_step.values())).to( - dtype=torch.int32, device="cpu" - ) + return self.pre_step_seqlens def convert_paged_to_nonpaged(self, layer_number: int): """ @@ -458,14 +476,14 @@ def pre_step( finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_( + self.batch_indices.data[:].copy_( torch.Tensor( ( unfinished_indices + finished_indices + list(range(prev_batch_size, self.max_batch_size)) ) - ).to(dtype=torch.int32, device="cpu") + ) ) # Advance unfinished sequences diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 142044240..5fd16bf1a 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -11,7 +11,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule -from transformer_engine.pytorch.module import LayerNormLinear, Linear +from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization from transformer_engine.pytorch.utils import ( SplitAlongDim, @@ -175,14 +175,23 @@ class MultiheadAttention(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument `fuse_wgrad_accumulation`. - use_qk_norm: bool, default = 'False' - if set to `True`, L2 normalization is applied to query and key tensors - after RoPE (if applicable) but before attention computation. - This follows the Llama4 approach for QK normalization to improve - training stability and model performance. + qk_norm_type: Optional[str], default = None + type of normalization to apply to query and key tensors. + Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied. + When 'L2Normalization', L2 normalization is applied to query and key tensors. + When 'RMSNorm', RMS normalization is applied to query and key tensors. + When 'LayerNorm', layer normalization is applied to query and key tensors. 
+ Normalization is applied after RoPE (if applicable) but before attention computation + when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach + for QK normalization to improve training stability and model performance. qk_norm_eps: float, default = 1e-6 - epsilon value for L2 normalization of query and key tensors. - Only used when `use_qk_norm` is True. + epsilon value for normalization of query and key tensors. + Only used when `qk_norm_type` is not None. + qk_norm_before_rope: bool, default = `False` + if set to `True`, query and key normalization is applied before rotary position + embedding. When `False` (default), normalization is applied after RoPE. + This parameter allows supporting different architectural variants that apply + QK normalization at different points. seq_length: Optional[int], default = `None` sequence length of input samples. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for @@ -231,8 +240,9 @@ def __init__( device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", name: str = None, - use_qk_norm: bool = False, + qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, + qk_norm_before_rope: bool = False, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, ) -> None: @@ -264,6 +274,7 @@ def __init__( qkv_weight_interleaved = False self.qkv_weight_interleaved = qkv_weight_interleaved self.rotary_pos_interleaved = rotary_pos_interleaved + self.qk_norm_before_rope = qk_norm_before_rope assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" if layer_number is not None: @@ -288,7 +299,6 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name - self.use_qk_norm = use_qk_norm common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -300,13 +310,9 @@ def __init__( "device": device, } - # Initialize L2 normalization modules for query and key if enabled - if self.use_qk_norm: - self.qk_norm = L2Normalization( - eps=qk_norm_eps, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - ) + self.q_norm, self.k_norm = self._create_qk_norm_modules( + qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size + ) qkv_parallel_mode = "column" if set_parallel_mode else None @@ -427,6 +433,78 @@ def __init__( **common_gemm_kwargs, ) + def _create_qk_norm_modules( + self, + qk_norm_type: Optional[str], + qk_norm_eps: float, + device: Union[torch.device, str], + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + ) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]: + """ + Create query and key normalization modules based on the specified normalization type. + + Parameters + ---------- + qk_norm_type : Optional[str] + Type of normalization to apply. 
Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm' + qk_norm_eps : float + Epsilon value for numerical stability + device : Union[torch.device, str] + Device to place the normalization modules on + seq_length : Optional[int], default = None + Sequence length for L2Normalization optimization + micro_batch_size : Optional[int], default = None + Micro batch size for L2Normalization optimization + + Returns + ------- + Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]] + Query and key normalization modules (q_norm, k_norm) + """ + if qk_norm_type is None: + return None, None + + if qk_norm_type == "L2Normalization": + l2_norm = L2Normalization( + eps=qk_norm_eps, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + ) + # L2Normalization is parameter-free, so we can share the same instance + return l2_norm, l2_norm + + if qk_norm_type == "RMSNorm": + q_norm = RMSNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + k_norm = RMSNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + return q_norm, k_norm + + if qk_norm_type == "LayerNorm": + q_norm = LayerNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + k_norm = LayerNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + return q_norm, k_norm + + raise ValueError( + f"Unsupported QK norm type: {qk_norm_type}. " + "Supported types: ['L2Normalization', 'RMSNorm', 'LayerNorm']" + ) + def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ Set the tensor parallel group for the given @@ -789,6 +867,14 @@ def forward( ) query_layer = query_layer.view(*new_tensor_shape) + # =========================== + # Apply normalization to query and key tensors (before RoPE if configured) + # =========================== + + if self.q_norm is not None and self.qk_norm_before_rope: + query_layer = self.q_norm(query_layer) + key_layer = self.k_norm(key_layer) + # ====================================================== # Apply relative positional encoding (rotary embedding) # ====================================================== @@ -803,32 +889,28 @@ def forward( q_pos_emb, k_pos_emb = rotary_pos_emb - # adjust key and value for inference - if inference_params is not None: - if self.qkv_format == "sbhd": - sequence_length = key_layer.size(0) - elif self.qkv_format == "bshd": - sequence_length = key_layer.size(1) - else: - raise ValueError( - f"qkv_format={self.qkv_format} not supported for KV caching and RoPE." - ) - - sequence_start = inference_params.get_seqlens_pre_step() - # sequence_start = inference_params.seqlens[0] - sequence_end = sequence_start + sequence_length + # Applying RoPE for inference needs the start positions of the sequences + # for each iteration. + sequence_start_positions = ( + inference_params.get_seqlens_pre_step() if inference_params is not None else None + ) - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
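To illustrate the new qk_norm_type / qk_norm_eps / qk_norm_before_rope options documented above, a minimal sketch of constructing MultiheadAttention with RMSNorm on the query and key heads; the tensor sizes are arbitrary and chosen only for the example.

    # Sketch only: RMS-normalize Q and K before the rotary embedding is applied.
    import torch
    from transformer_engine.pytorch import MultiheadAttention

    mha = MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        qk_norm_type="RMSNorm",       # None, "L2Normalization", "RMSNorm", or "LayerNorm"
        qk_norm_eps=1e-6,
        qk_norm_before_rope=True,     # default False applies the norm after RoPE
    )

    x = torch.randn(128, 2, 1024, device="cuda")  # [s, b, hidden] for the default "sbhd" layout
    y = mha(x)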
+ if pad_between_seqs: + rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded + rotary_pos_cu_seq_lens_kv = cu_seqlens_kv_padded + else: + rotary_pos_cu_seq_lens_q = cu_seqlens_q + rotary_pos_cu_seq_lens_kv = cu_seqlens_kv query_layer = apply_rotary_pos_emb( query_layer, q_pos_emb, self.qkv_format, fused=True, - cu_seqlens=cu_seqlens_q, + cu_seqlens=rotary_pos_cu_seq_lens_q, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) key_layer = apply_rotary_pos_emb( @@ -836,19 +918,20 @@ def forward( k_pos_emb, self.qkv_format, fused=True, - cu_seqlens=cu_seqlens_kv, + cu_seqlens=rotary_pos_cu_seq_lens_kv, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) # =========================== - # Apply L2 normalization to query and key tensors + # Apply normalization to query and key tensors (after RoPE if not applied before) # =========================== - if self.use_qk_norm: - query_layer = self.qk_norm(query_layer) - key_layer = self.qk_norm(key_layer) + if self.q_norm is not None and not self.qk_norm_before_rope: + query_layer = self.q_norm(query_layer) + key_layer = self.k_norm(key_layer) # =========================== # Core attention computation diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 60685a31d..139381f2d 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -5,14 +5,14 @@ """ Rotary Position Embedding implementation of different types along with helper functions """ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"] +__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] class RotaryPositionEmbedding(torch.nn.Module): @@ -170,6 +170,86 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], return grad_input, None, None, None, None, None, None, None +class FusedQKVRoPEFunc(torch.autograd.Function): + """ + Function for FusedQKVRoPE + + This implementation accepts combined QKV tensor in `bshd` or `sbhd` format. Q and K RoPE tensors are the additional required inputs. + The RoPE tensors should be of shape (s, 1, 1, d). It produces 3 outputs: Q, K after RoPE, V is the same as input. + """ + + @staticmethod + def forward( + ctx, + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + start_positions: Union[torch.Tensor, None] = None, + tensor_format: str = "sbhd", + interleaved: bool = False, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + """Fused RoPE forward.""" + + if q_freqs.dtype != torch.float32: + q_freqs = q_freqs.float() + if k_freqs.dtype != torch.float32: + k_freqs = k_freqs.float() + assert tensor_format in ( + "sbhd", + "bshd", + ), f"Unsupported tensor_format: {tensor_format}." + assert qkv.is_contiguous(), "QKV Tensor should be contiguous." + assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous." + assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous." 
+ output = tex.fused_qkv_rope_forward( + qkv, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + QKVFormat[tensor_format], + interleaved, + cp_size, + cp_rank, + ) + ctx.save_for_backward(q_freqs, k_freqs) + ctx.tensor_format = tensor_format + ctx.qkv_split_arg_list = qkv_split_arg_list + ctx.cp_size = cp_size + ctx.cp_rank = cp_rank + ctx.interleaved = interleaved + return output + + @staticmethod + def backward( + ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused RoPE backward.""" + q_freqs, k_freqs = ctx.saved_tensors + + grad_output_q = grad_output_q.contiguous() + grad_output_k = grad_output_k.contiguous() + grad_output_v = grad_output_v.contiguous() + + grad_input = tex.fused_qkv_rope_backward( + grad_output_q, + grad_output_k, + grad_output_v, + q_freqs, + k_freqs, + ctx.qkv_split_arg_list, + QKVFormat[ctx.tensor_format], + ctx.interleaved, + ctx.cp_size, + ctx.cp_rank, + ) + + return grad_input, None, None, None, None, None, None, None, None + + def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: """Change sign so the last dimension becomes [-odd, +even] @@ -393,3 +473,82 @@ def apply_rotary_pos_emb( tensor_format, interleaved=interleaved, ) + + +def apply_fused_qkv_rotary_pos_emb( + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + tensor_format: str = "sbhd", + start_positions: Union[torch.Tensor, None] = None, + interleaved: bool = False, + cu_seqlens: Union[torch.Tensor, None] = None, # pylint: disable=unused-argument + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input qkv tensor. + + Support matrix: + Fused: + Training: + qkv_formats: "bshd", "sbhd" + context parallel: yes + start_positions: no + interleaving: yes + Inference: + qkv_formats: "bshd", "sbhd" + context parallelism: no + start_positions: yes + interleaving: yes + + Parameters + ---------- + qkv: torch.Tensor + Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which + rotary positional embedding will be applied. This tensor has q, k, v concatenated + along the last dimension. + q_freqs: torch.Tensor + Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + k_freqs: torch.Tensor + Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + qkv_split_arg_list: List[int] + List of integers that specify the split of the qkv tensor. The list should have 3 elements, + the first element is the number of elements in the q tensor, the second element is the number + of elements in the k tensor, and the third element is the number of elements in the v tensor. + The sum of the elements in the list should be equal to the last dimension of the qkv tensor. + start_positions: torch.Tensor, default = None. + Tokens in a sequence `i` should be applied with position encoding offset by + `start_positions[i]`. If `start_positions=None`, there's no offset. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is + of shape `[seq, bs, ...]`. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + cp_size: int, default = 1. + Context parallel world size. + cp_rank: int, default = 0. + Context parallel rank. 
+ """ + + # `start_positions` is only supported for `cp_size=1` and inference. + assert not ( + cp_size > 1 and start_positions is not None + ), """start_positions != None with CP SIZE > 1 is not supported!""" + + assert tensor_format != "thd", "'thd' tensor_format not supported currently." + + return FusedQKVRoPEFunc.apply( + qkv, + q_freqs, + k_freqs, + qkv_split_arg_list, + start_positions, + tensor_format, + interleaved, + cp_size, + cp_rank, + ) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 9f3921d36..e4f4e619f 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -21,6 +21,15 @@ ] +def validate_gemm_scale(scale: Optional[float], required: bool) -> float: + """Validate whether a GEMM scaling factor is consistent with its usage""" + if required: + return scale if scale is not None else 1.0 + if scale not in (0.0, None): + raise ValueError("scale must be zero") + return 0.0 + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -29,6 +38,8 @@ def general_gemm( quantization_params: Optional[Quantizer] = None, gelu: bool = False, gelu_in: torch.Tensor = None, + alpha: float = 1.0, + beta: Optional[float] = None, accumulate: bool = False, layout: str = "TN", out: Optional[torch.Tensor] = None, @@ -47,6 +58,9 @@ def general_gemm( transb = layout[1] == "T" # assert quantization_params is None, "FP8 output not supported yet" + alpha = validate_gemm_scale(alpha, True) + beta = validate_gemm_scale(beta, accumulate) + if ub_type is not None: assert ub is not None, ( f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" @@ -108,6 +122,8 @@ def general_gemm( "comm_type": ub_type, "extra_output": extra_output, "bulk_overlap": bulk_overlap, + "alpha": alpha, + "beta": beta, } out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 1c03e3d37..179c80a65 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -431,7 +431,7 @@ def tensor_pop(self, tensor_tag, **kwargs): tensor = self.fp8_tensor_object_map.pop(tensor_tag) if self.double_buffering: - tensor.do_not_clear = True + tensor._do_not_clear = True self.tensor_tag_to_buf.pop(tensor_tag, None) # the tensor should have been copied back in on_group_commit_backward() @@ -551,26 +551,48 @@ def bulk_reload_group(self, group_to_reload): buffer_idx = 0 double_buffer_idx = group_to_reload % 2 + main_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: + if isinstance(state, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state[1], device=torch.cuda.current_device() + ) + recovered_tensor = SynchronizedGroupOffloadHandler.reload( - state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx] + state, True, reload_buffer ) buffer_idx = buffer_idx + 1 self.tensor_tag_to_state[tensor_label] = recovered_tensor elif isinstance(state, list): tensor_list = [] for state_tuple in state: + if isinstance(state_tuple, tuple): + if self.double_buffering: + reload_buffer = 
self.reload_double_buffer[double_buffer_idx][ + buffer_idx + ] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state_tuple[1], device=torch.cuda.current_device() + ) + tensor_list.append( SynchronizedGroupOffloadHandler.reload( state_tuple, True, - self.reload_double_buffer[double_buffer_idx][buffer_idx], + reload_buffer, ) ) buffer_idx = buffer_idx + 1 diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index 75b5de37b..076dbec0d 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -29,6 +29,7 @@ def forward( reduce_loss=False, dist_process_group=None, ignore_idx=-100, + is_cg_capturable=False, ): """ The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each @@ -47,10 +48,16 @@ def forward( tensor: The computed loss. """ loss, _input = triton_cross_entropy.cross_entropy_forward( - _input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx + _input, + target, + label_smoothing, + reduce_loss, + dist_process_group, + ignore_idx, ) ctx.save_for_backward(_input.detach()) + ctx.is_cg_capturable = is_cg_capturable return loss @staticmethod @@ -66,13 +73,17 @@ def backward(ctx, grad_output): tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ (_input,) = ctx.saved_tensors - _input = triton_cross_entropy.cross_entropy_backward(_input, grad_output) + _input = triton_cross_entropy.cross_entropy_backward( + _input, grad_output, ctx.is_cg_capturable + ) return ( _input, None, None, None, None, + None, + None, ) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 8e47cf60c..aa6602401 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -18,7 +18,7 @@ namespace transformer_engine::pytorch { -std::vector getTensorShape(at::Tensor t) { +std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { shape.push_back(s); @@ -292,7 +292,7 @@ std::vector convertShape(const NVTEShape& shape) { return std::vector(shape.data, shape.data + shape.ndim); } -int roundup(const int value, const int multiple) { +size_t roundup(const size_t value, const size_t multiple) { assert(multiple > 0); return ((value + multiple - 1) / multiple) * multiple; } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 29875584b..07384413d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -108,9 +108,21 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; - virtual std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const = 0; + /*! @brief Construct a tensor with uninitialized data */ + virtual std::pair create_tensor(const std::vector& shape, + DType dtype) const = 0; + + /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor + * + * The PyTorch tensor's attributes are modified to match the + * quantizer's configuration. + */ + virtual std::pair convert_and_update_tensor( + py::object tensor) const = 0; + + /*! 
@brief Convert to a quantized data format */ + virtual void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) = 0; virtual ~Quantizer() = default; @@ -131,9 +143,17 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const std::vector& shape, DType dtype, + at::Tensor data) const; + + std::pair convert_and_update_tensor(py::object tensor) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8Quantizer : public Quantizer { @@ -149,9 +169,19 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const std::vector& shape, DType dtype, + std::optional data, + std::optional transpose, + std::optional scale_inv) const; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -171,9 +201,29 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct a high precision tensor giving it this quantizer's amax + + Note: this member function also zeros out the amax, as it is meant to be used in conjunction with + a kernel computing the amax, which might expect the amax to be initialized to zero + */ + std::pair create_hp_tensor_with_amax(const std::vector& shape, + DType dtype); + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; + + /*! @brief Convert to a quantized data format avoiding amax computation */ + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt); + + private: + void quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, bool compute_amax); }; class Float8BlockQuantizer : public Quantizer { @@ -205,9 +255,13 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
- std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -222,16 +276,20 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(at::Tensor t); +std::vector getTensorShape(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -383,7 +441,7 @@ void* getDataPtr(at::Tensor tensor, int offset = 0); std::vector convertShape(const NVTEShape& shape); -int roundup(const int value, const int multiple); +size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 72151f41a..9b527b161 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -13,6 +13,10 @@ #include "common.h" +class CommOverlapHelper; +class CommOverlap; +class CommOverlapP2P; + #ifdef USE_ROCM namespace transformer_engine { //dummy CommOverlapCore, CommOverlapType in rocm @@ -128,7 +132,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, std::optional comm_type = std::nullopt, - MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false); + MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false, + float alpha = 1.0f, std::optional beta = std::nullopt); void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, std::vector A_scaling_mode, bool transa, at::Tensor B, @@ -153,42 +158,55 @@ std::optional> te_general_grouped_gemm( at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output = std::nullopt); +at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); + /*************************************************************************************************** * Activations **************************************************************************************************/ +/* GELU and variants*/ py::object gelu(const at::Tensor &input, py::handle quantizer); -py::object relu(const at::Tensor &input, py::handle quantizer); +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); py::object geglu(const at::Tensor &input, py::handle quantizer); -py::object qgeglu(const at::Tensor &input, py::handle quantizer); +py::object 
dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object reglu(const at::Tensor &input, py::handle quantizer); +py::object qgelu(const at::Tensor &input, py::handle quantizer); -py::object swiglu(const at::Tensor &input, py::handle quantizer); +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object qgelu(const at::Tensor &input, py::handle quantizer); +py::object qgeglu(const at::Tensor &input, py::handle quantizer); -py::object srelu(const at::Tensor &input, py::handle quantizer); +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +/* ReLU and variants*/ +py::object relu(const at::Tensor &input, py::handle quantizer); py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object reglu(const at::Tensor &input, py::handle quantizer); py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object srelu(const at::Tensor &input, py::handle quantizer); py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object sreglu(const at::Tensor &input, py::handle quantizer); + +py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +/* Silu and variants*/ +py::object silu(const at::Tensor &input, py::handle quantizer); + +py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object swiglu(const at::Tensor &input, py::handle quantizer); + +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + /*************************************************************************************************** * LayerNorm **************************************************************************************************/ @@ -211,6 +229,11 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); +std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &add, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma); + std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object ln_out, py::handle quantizer, DType otype, const int sm_margin, const bool zero_centered_gamma); @@ -252,6 +275,17 @@ std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Te std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer); +/*************************************************************************************************** + * Dropout + **************************************************************************************************/ + +std::vector dropout_fwd(const py::handle &input, const float dropout_probability, + std::optional out = std::nullopt); + +py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, + const 
float dropout_probability, + std::optional grad_input = std::nullopt); + /*************************************************************************************************** * Softmax **************************************************************************************************/ @@ -314,6 +348,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor const std::optional cu_seqlens, const int cp_size, const int cp_rank); +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank); + +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank); + /*************************************************************************************************** * Miscellaneous **************************************************************************************************/ @@ -447,6 +493,15 @@ void rocshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_ void rocshmem_finalize(); #endif +#ifndef USE_ROCM +/*************************************************************************************************** + * Comm+GEMM Overlap Wrappers + **************************************************************************************************/ + +void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream, + at::Stream recv_stream); +#endif // !USE_ROCM + } // namespace transformer_engine::pytorch #ifndef USE_ROCM @@ -497,7 +552,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); - at::Stream get_communication_stream(); + std::pair get_communication_stream(); }; // CommOverlap @@ -518,7 +573,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); - at::Stream get_communication_stream(); + std::pair get_communication_stream(); }; // CommOverlapP2P #endif // !USE_ROCM diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2ec8ec0bf..8b0607c9e 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -15,98 +15,95 @@ namespace transformer_engine::pytorch { template py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - input_shape[input_shape.size() - 1] /= shape_divisor; - auto fake_tensor_type = input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - - // for current scaling, we need to 
compute amax first and then quantize - // because cache cannot fit in the entire tensor to compute amax and quantize - // the quantizer should not need amax reduction, no process group needed here - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // activation function might change the input data range, we need to first call the activation function - // and then find the amax and scale of that and then do the quantization - // get a NoneQuantizer to calculate amax of activation output - auto my_quantizer_none = std::make_unique(py::none()); - auto [te_output_act, out_act] = - my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - -#ifdef __HIP_PLATFORM_AMD__ - at::Tensor ws = allocate_amax_workspace(te_output_act); - TensorWrapper tw = makeTransformerEngineTensor(ws); -#endif - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream()); - // use te_output_act as input to the compute amax and find the amax of activated tensor -#ifdef __HIP_PLATFORM_AMD__ - nvte_compute_amax_with_workspace(te_output_act.data(), te_output.data(), - tw.data(), at::cuda::getCurrentCUDAStream()); -#else - nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); -#endif - }); - - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - if (my_quantizer_cs->with_amax_reduction) { - NVTE_ERROR( - "per-tensor current scaling amax reduction is not supported in activation functions."); - } - QuantizationConfigWrapper quant_config; - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + // Input tensor + auto input_tensor = input.contiguous(); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct output tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape = input_cpp.shape(); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + output_shape.back() /= shape_divisor; + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + + // Compute activation + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation directly + NVTE_SCOPED_GIL_RELEASE( + { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // Compute activation in high-precision fused together with amax, then quantize. 
- NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - // sanity check, since activation fusion is not supported for blockwise quantization yet - // need to raise an error here instead of silently going into act_func with wrong numerics - NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); + auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); } else { + // Compute activation in high-precision, then quantize + + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE( - { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); }); + { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp->quantize(temp_cpp, out_cpp); } - return out; + return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, +template +py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, py::handle quantizer) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - auto grad_tensor = grad.contiguous(); - - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - auto fake_tensor_type = input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); + // Grad output and input tensors + auto grad_output_tensor = grad_output.contiguous(); + auto input_tensor = input.contiguous(); + const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct grad input tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape_te = input_cpp.shape(); + const std::vector input_shape(input_shape_te.data, + input_shape_te.data + input_shape_te.ndim); + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + + // Compute activation backward + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation backward directly + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + } else if 
(detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // Compute activation backward in high-precision fused together with amax, then quantize. + auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); + } else { + // Compute activation backward in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + } - return out; + return grad_input_py; } +/* GELU and variants*/ py::object gelu(const at::Tensor& input, py::handle quantizer) { return activation_helper(input, quantizer); } @@ -115,30 +112,39 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua return dactivation_helper(grad, input, quantizer); } -py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object geglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object qgelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object qgeglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { return dactivation_helper(grad, input, quantizer); } +/* ReLU and variants*/ +py::object relu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + py::object reglu(const at::Tensor& input, py::handle quantizer) { return activation_helper(input, quantizer, 2); } @@ -147,28 +153,36 @@ py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle qu return dactivation_helper(grad, input, quantizer); } -py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object srelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return 
dactivation_helper(grad, input, quantizer); +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object sreglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +/* Silu and variants*/ +py::object silu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } +py::object swiglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 6f6f82725..064da8a67 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); - auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor if (start_positions) { start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); } if (qkv_format == NVTE_QKV_Format::NVTE_THD) { @@ -102,6 +103,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, return output; } +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, + const int cp_rank) { + TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + // output + auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device()); 
+ auto q_out_size = qkv_input.sizes().vec(); + q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1]; + q_out_size[3] = qkv_split_arg_list[1]; + auto q_out = at::empty(q_out_size, act_options); + auto k_out_size = qkv_input.sizes().vec(); + k_out_size[3] = qkv_split_arg_list[1]; + auto k_out = at::empty(k_out_size, act_options); + auto v_out_size = qkv_input.sizes().vec(); + v_out_size[3] = qkv_split_arg_list[2]; + auto v_out = at::empty(v_out_size, act_options); + + auto qkv_cu = makeTransformerEngineTensor(qkv_input); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto q_out_cu = makeTransformerEngineTensor(q_out); + auto k_out_cu = makeTransformerEngineTensor(k_out); + auto v_out_cu = makeTransformerEngineTensor(v_out); + + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor + if (start_positions) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + } + + TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor"); + TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous"); + + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1); + const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0); + const int h = qkv_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(), + start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(), + v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h, + d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1], + qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(q_out, k_out, v_out); +} + at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional<at::Tensor> cu_seqlens, const int cp_size, @@ -193,4 +253,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor return input_grads; } +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank) { + auto act_options = + at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device()); + auto qkv_grad_size = q_grad_out.sizes().vec(); + auto total_hd = + (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3); + auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2]; + qkv_grad_size[2] = total_hd / total_d; + qkv_grad_size[3] = total_d; + auto qkv_grad_input = at::empty(qkv_grad_size, act_options); + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1); + const int b = is_sbhd ?
q_grad_out.size(1) : q_grad_out.size(0); + const int h = qkv_grad_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out); + auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out); + auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input); + + nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(), + q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0], + qkv_split_arg_list[1], qkv_split_arg_list[2], + at::cuda::getCurrentCUDAStream()); + + return qkv_grad_input; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 71a8062b1..6d835a5c9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s auto max_tokens = shape[0]; auto fcd_size = 1; - for (int i = 1; i <= shape.size(); i++) { + for (size_t i = 1; i <= shape.size(); i++) { fcd_size *= shape[i]; } @@ -103,8 +103,20 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; py::object o_python, s_python; - std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // Initialize FP8 tensor with scale-inverse + auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + } else { + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + } auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; // construct NVTE tensors @@ -284,8 +296,20 @@ std::vector fused_attn_bwd( py::object s_python, dp_python; std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_S, s_python) = 
S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + } else { + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + } std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); @@ -374,9 +398,22 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_dQ, py_dQ) = + fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); + std::tie(te_dK, py_dK) = + fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); + std::tie(te_dV, py_dV) = + fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); + } else { + auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); + std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + } // construct NVTE tensors if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index bfa878c30..f65614d07 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -6,91 +6,223 @@ * See LICENSE for license information. 
************************************************************************/ +#include +#include + +#include +#include + #include "common.h" +#include "extensions.h" #include "pybind.h" #include "transformer_engine/cast.h" #include "transformer_engine/transformer_engine.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { -std::vector bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) { - auto quantizer = convert_quantizer(py_quantizer); +std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer) { + using namespace transformer_engine::pytorch::detail; + init_extension(); - auto input_tensor = makeTransformerEngineTensor(input); + // Grad output tensor + auto grad_output_torch = grad_output.contiguous(); + const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); + const auto shape = getTensorShape(grad_output_torch); + auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); - auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype()); + // Construct grad bias tensor + const int64_t bias_size = static_cast(shape.back()); + auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); + auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch); - std::vector output_shape; - for (auto s : input.sizes()) { - output_shape.emplace_back(static_cast(s)); + // Unquantized impl only requires computing grad bias + if (quantizer.is_none()) { + if (product(shape) == 0) { + grad_bias_torch.zero_(); + } else { + at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); + } + return {py::cast(std::move(grad_bias_torch)), py::cast(std::move(grad_output_torch))}; } - auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype()); - // Return immediately if tensors are empty - if (product(output_shape) == 0) { - return {py::cast(dbias.zero_()), out}; + // Construct grad input tensor + auto quantizer_cpp = convert_quantizer(quantizer); + auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(shape, grad_output_dtype); + + // Trivial impl if tensors are empty + if (product(shape) == 0) { + grad_bias_torch.zero_(); + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; + } + + // Unfused impl if quantizer is not supported + const bool with_fused_dbias_quantize_kernel = + detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr()); + if (!with_fused_dbias_quantize_kernel) { + at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); + quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; } - auto dbias_tensor = makeTransformerEngineTensor(dbias); - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; + // Query workspace size + TensorWrapper workspace_nvte; + at::Tensor workspace_torch; + auto stream = at::cuda::getCurrentCUDAStream(); NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(), + workspace_nvte.data(), stream); }); - void* workspace_data_ptr = nullptr; - if (workspace.shape().ndim > 0) { - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace_data_ptr = 
workspace_data.data_ptr(); - } - workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype()); - - // Launch kernel - if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(quantizer.get()); -#ifdef __HIP_PLATFORM_AMD__ - at::Tensor ws = allocate_amax_workspace(input_tensor); - TensorWrapper tw = makeTransformerEngineTensor(ws); -#endif - - NVTE_SCOPED_GIL_RELEASE({ -#ifdef __HIP_PLATFORM_AMD__ - nvte_compute_amax_with_workspace(input_tensor.data(), out_tensor.data(), - tw.data(), - at::cuda::getCurrentCUDAStream()); -#else - nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); -#endif - }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - QuantizationConfigWrapper quant_config; - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(out_tensor.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape); + // Allocate workspace + if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) { + workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype()); + workspace_nvte = makeTransformerEngineTensor(workspace_torch.data_ptr(), workspace_nvte.shape(), + workspace_nvte.dtype()); } + + // Launch fused kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(), + workspace_nvte.data(), stream); }); - return {py::cast(dbias), out}; + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; +} + +namespace { + +std::vector dact_dbias( + void (*dact_dbias_func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor, + cudaStream_t), + void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t), + at::Tensor grad_output_torch, at::Tensor act_input_torch, py::handle quantizer_py) { + using namespace transformer_engine::pytorch::detail; + init_extension(); + + // Grad output and activation input tensors + grad_output_torch = grad_output_torch.contiguous(); + const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); + const auto output_shape = getTensorShape(grad_output_torch); + auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); + act_input_torch = act_input_torch.contiguous(); + const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); + const auto input_shape = getTensorShape(act_input_torch); + + // Construct tensors + auto 
quantizer_cpp = convert_quantizer(quantizer_py); + auto [grad_input_nvte, grad_input_py] = + quantizer_cpp->create_tensor(input_shape, grad_output_dtype); + const int64_t bias_size = static_cast(input_shape.back()); + auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); + auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch); + + // Return immediately if tensors are empty + if (product(output_shape) == 0) { + grad_bias_torch.zero_(); + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; + } + + // Choose implementation + enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX }; + Impl impl = Impl::UNFUSED; + if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || + detail::IsMXFP8Quantizers(quantizer_py.ptr())) { + impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { + impl = Impl::FUSED_DACT_AMAX; + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Unfused dact, dbias, quantize + { + auto [temp_nvte, temp_py] = + NoneQuantizer(py::none()).create_tensor(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); + quantizer_cpp->quantize(temp_nvte, grad_input_nvte); + break; + } + case Impl::FUSED_DACT_DBIAS_QUANTIZE: + // Fused dact-dbias-quantize kernel + { + // Query workspace size + TensorWrapper workspace_nvte; + NVTE_SCOPED_GIL_RELEASE({ + dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(), + grad_bias_nvte.data(), workspace_nvte.data(), stream); + }); + + // Allocate workspace + at::Tensor workspace_torch; + if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) { + workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype()); + workspace_nvte = makeTransformerEngineTensor( + workspace_torch.data_ptr(), workspace_nvte.shape(), workspace_nvte.dtype()); + } + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE({ + dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(), + grad_bias_nvte.data(), workspace_nvte.data(), stream); + }); + break; + } + case Impl::FUSED_DACT_AMAX: + // Fused dact-amax kernel, unfused dbias and quantize + { + auto *quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(quantizer_cpp_cs != nullptr, + "Invalid quantizer for fused dact-amax kernel impl"); + auto [temp_nvte, temp_py] = + quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); + quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte); + break; + } + default: + NVTE_ERROR("Invalid implementation"); + } + + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; +} + +} // namespace + +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer); +} + +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return 
dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer); +} + +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer); +} + +std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer); +} + +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer); } -} // namespace transformer_engine::pytorch +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 877c4f6e0..c940181b0 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -30,73 +30,6 @@ std::vector get_tensor_shape(const TensorWrapper &tensor) { return std::vector(shape.data, shape.data + shape.ndim); } -void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py, - std::unique_ptr &quantizer_cpp, TensorWrapper &output, - TensorWrapper &noop_flag) { - // Check tensor dims - NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output), - "Input tensor (shape=", get_tensor_shape(input), - ") and output tensor (shape=", get_tensor_shape(output), ") do not match"); - if (input.numel() == 0) { - return; - } - - // Recipe-specific configuration - QuantizationConfigWrapper quant_config; - quant_config.set_noop_tensor(noop_flag.data()); - - if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(quantizer_cpp.get()); -#ifdef __HIP_PLATFORM_AMD__ - at::Tensor ws = allocate_amax_workspace(input); - TensorWrapper tw = makeTransformerEngineTensor(ws); -#endif - NVTE_SCOPED_GIL_RELEASE({ -#ifdef __HIP_PLATFORM_AMD__ - nvte_compute_amax_with_workspace(input.data(), output.data(), - tw.data(), - at::cuda::getCurrentCUDAStream()); -#else - nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); -#endif - }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - // this config is used for cs scaling factor computation - // because compute scale is cannot be fused with quantize kernel - // so in nvte_quantize_v2 with current scaling, the quant config is not used again - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel - 
output.set_amax(nullptr, DType::kFloat32, output.defaultShape); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) { - auto my_quantizer_bw = static_cast(quantizer_cpp.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } - } - - // Perform quantization - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); -} - } // namespace py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, @@ -116,18 +49,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob const auto fake_dtype = input_cpp.dtype(); std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); } else { - output_py = output; - output_cpp = makeTransformerEngineTensor(output_py, quantizer); + std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); } // Initialize no-op flag - TensorWrapper noop_flag_cpp; + std::optional noop_flag_cpp; if (noop_flag.has_value()) { noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); } // Perform quantization - quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp); + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); return output_py; } @@ -197,10 +129,8 @@ void multi_tensor_quantize_impl(const std::vector &input_list, }); } else { // Quantize kernels individually - TensorWrapper dummy_noop_flag; for (size_t i = 0; i < num_tensors; ++i) { - quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i], - dummy_noop_flag); + quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]); } } } @@ -277,11 +207,8 @@ std::tuple, std::vector> bulk_allocate_fp auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); - // in the case where full buffer is empty because local rank receives no tokens for all the experts - // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob - // but in the case where some experts receive tokens, some not, we want to leverage from_blob - // as much as possible to avoid CPU overhead - if (buffer->data_ptr() == nullptr) { + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } return at::from_blob( @@ -431,11 +358,8 @@ std::tuple, std::vector> bulk_allocate_mx auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); - // in the case where full buffer is empty because local rank receives no tokens for all the experts - // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob - // but in the case where some experts receive tokens, some not, we want to leverage from_blob - // as much as possible to avoid CPU overhead - if (buffer->data_ptr() == nullptr) { + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } return 
at::from_blob( @@ -470,11 +394,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -513,11 +434,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -665,66 +583,5 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } -template -std::vector dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - - auto grad_tensor = makeTransformerEngineTensor(grad_output); - - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); - auto act_input_tensor = makeTransformerEngineTensor(act_input); - - const auto &shape = convertShape(grad_tensor.shape()); - auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); - - auto dbias_tensor = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - NVTE_SCOPED_GIL_RELEASE({ - func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - }); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE({ - func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - }); - - return {py::cast(grad_bias), dact}; -} - -std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - } // namespace pytorch } // namespace transformer_engine diff --git 
a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 737c1d707..4aa2df2c9 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -218,8 +218,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional CommOverlap::get_communication_stream() { + // Return the same stream for both send and recv + return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()), + at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())}; } /*************************************************************************************************** @@ -307,7 +309,15 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional CommOverlapP2P::get_communication_stream() { + return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()), + at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; +} + +void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( + CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { + auto main_stream = at::cuda::getCurrentCUDAStream(); + allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), + at::cuda::CUDAStream(recv_stream), main_stream); } #endif // !USE_ROCM diff --git a/transformer_engine/pytorch/csrc/extensions/dropout.cpp b/transformer_engine/pytorch/csrc/extensions/dropout.cpp new file mode 100644 index 000000000..d009cf2f3 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/dropout.cpp @@ -0,0 +1,92 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/dropout.h" + +#include +#include + +#include + +#include "../common.h" +#include "../extensions.h" +#include "../pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace pytorch { + +std::vector dropout_fwd(const py::handle &input, float dropout_probability, + std::optional out) { + using namespace transformer_engine::pytorch::detail; + + // Input tensor + const TensorWrapper input_nvte = makeTransformerEngineTensor(input, py::none()); + + // Allocate output tensor if needed + if (!out) { + at::ScalarType dtype = GetATenDType(input_nvte.dtype()); + if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2 || + dtype == at::kFloat8_e4m3fnuz || dtype == at::kFloat8_e5m2fnuz) { + dtype = input.attr("dtype").cast(); + } + const auto shape_uint64 = convertShape(input_nvte.shape()); + const std::vector shape_int64(shape_uint64.begin(), shape_uint64.end()); + const auto opts = at::TensorOptions().dtype(dtype).device(torch::kCUDA); + out = at::empty(shape_int64, opts); + } + TensorWrapper out_nvte = makeTransformerEngineTensor(*out); + + // Mask tensor + auto mask_pyt = allocateTorchTensor(input_nvte.numel() / 8, DType::kByte); + auto mask_nvte = makeTransformerEngineTensor(mask_pyt); + + // RNG state tensor + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args; + { + std::lock_guard lock(gen->mutex_); + constexpr int64_t rng_elts_per_thread = 4; + philox_args = gen->philox_cuda_state(rng_elts_per_thread); + } + auto rng_state_pyt = allocateTorchTensor(2, DType::kInt64); + NVTE_SCOPED_GIL_RELEASE({ + nvte_extract_seed_and_offset( + reinterpret_cast(rng_state_pyt.data_ptr()), philox_args.captured_, + philox_args.seed_.ptr, philox_args.seed_.val, philox_args.offset_.ptr, + philox_args.offset_.val, philox_args.offset_intragraph_, at::cuda::getCurrentCUDAStream()); + }); + auto rng_state_nvte = makeTransformerEngineTensor(rng_state_pyt); + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE({ + nvte_dropout_fwd(input_nvte.data(), out_nvte.data(), mask_nvte.data(), rng_state_nvte.data(), + dropout_probability, at::cuda::getCurrentCUDAStream()); + }); + + return {py::cast(std::move(*out)), py::cast(mask_pyt)}; +} + +py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, + const float dropout_probability, std::optional grad_input) { + const auto grad_output_nvte = makeTransformerEngineTensor(grad_output); + const auto mask_nvte = makeTransformerEngineTensor(mask); + if (!grad_input) { + grad_input = at::empty_like(grad_output); + } + auto grad_input_nvte = makeTransformerEngineTensor(*grad_input); + NVTE_SCOPED_GIL_RELEASE({ + nvte_dropout_bwd(grad_output_nvte.data(), mask_nvte.data(), grad_input_nvte.data(), + dropout_probability, at::cuda::getCurrentCUDAStream()); + }); + return py::cast(std::move(*grad_input)); +} + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index d8696c14d..b637d49c7 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -94,7 +94,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, CommOverlapCore* comm_overlap, 
std::optional comm_type, MaybeTensor extra_output, - bool bulk_overlap) { + bool bulk_overlap, float alpha, std::optional beta) { + using namespace transformer_engine::pytorch::detail; + // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -112,10 +114,23 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); + // Check scaling factors + if (accumulate) { + if (!beta) { + beta = 1.0f; + } + } else { + if (!beta) { + beta = 0.0f; + } + NVTE_CHECK(beta == 0.0, "Trying to use non-zero beta while not accumulating ", + "into D tensor. Beta has nothing to be applied to."); + } + + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); @@ -128,12 +143,35 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } + // maintain unquantized tensor in case we need unfused quantization support. + TensorWrapper unquantized_D_tensor; + py::object unquantized_out; + // Unfused quantization is needed in the following cases + // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that) + // 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling, + // GEMM Output needs to be in BF16, to allow for unfused quantization) + bool unfused_quantization_needed = !quantizer.is_none(); + if (low_precision) { + // At the moment, only use-case for fused GEMM: + // Delayed scaling quantizer with per-tensor scaling inputs + bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr()); + if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input) + unfused_quantization_needed = false; + } + + if (unfused_quantization_needed) { + NoneQuantizer q{none}; + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + } + TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor; + // Bias tensor TensorWrapper bias_tensor; MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + auto opts = + torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -146,7 +184,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : D_tensor.dtype(); + DType gelu_type = low_precision ? 
bias_type : out_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -204,7 +242,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -212,14 +250,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -228,14 +266,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -248,14 +286,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), - te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), - accumulate, use_split_accumulator, num_math_sms, main_stream); + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), + bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, + te_workspace.data(), alpha, *beta, use_split_accumulator, + num_math_sms, main_stream); }); } } else { - if (D_tensor.numel() != 0 && !accumulate) { - D_tensor.zero_(main_stream); + if (out_tensor.numel() != 0 && !accumulate) { + out_tensor.zero_(main_stream); } if (bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -263,7 +302,11 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - + if (unfused_quantization_needed) { + // Quantize the output + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + my_quantizer->quantize(unquantized_D_tensor, D_tensor); + } // Pack outputs std::vector out; out.emplace_back(std::move(D)); @@ -336,12 +379,8 @@ std::optional> te_general_grouped_gemm( size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { std::vector te_A_vector, te_B_vector, 
te_D_vector, te_bias_vector, te_pre_gelu_out_vector, te_workspace_vector; - std::vector wrappers; + std::vector te_A_wrappers, te_B_wrappers, wrappers; std::vector D_vectors; -#ifndef USE_ROCM - // Keep the swizzled scaling factor tensors alive during the GEMMs. - std::vector> swizzled_scale_inverses_list; -#endif auto none = py::none(); @@ -408,12 +447,6 @@ std::optional> te_general_grouped_gemm( continue; } -#ifndef USE_ROCM - // Optionally swizzle the scaling factors - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa))); - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb))); -#endif - auto te_D = makeTransformerEngineTensor(out_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); @@ -433,25 +466,33 @@ std::optional> te_general_grouped_gemm( te_bias_vector.emplace_back(te_bias.data()); te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - wrappers.emplace_back(std::move(te_A)); - wrappers.emplace_back(std::move(te_B)); + te_A_wrappers.emplace_back(std::move(te_A)); + te_B_wrappers.emplace_back(std::move(te_B)); wrappers.emplace_back(std::move(te_D)); wrappers.emplace_back(std::move(te_bias)); wrappers.emplace_back(std::move(te_pre_gelu_out)); } + +#ifndef USE_ROCM + // Optionally swizzle the scaling factors + // Keep the swizzled scaling factor tensors alive during the GEMMs. + auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); + auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); +#endif + for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); wrappers.emplace_back(std::move(wsp)); } + // For now, we only have multi-stream cublas backend. 
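+ // Calling convention assumed here: nvte_multi_tensor_gemm consumes flat arrays of
+ // NVTETensor handles (te_A_vector, te_B_vector, etc. hold one entry per GEMM in the
+ // group) plus the group size, and the handles are non-owning views, so the
+ // TensorWrapper objects collected in te_A_wrappers/te_B_wrappers/wrappers (and the
+ // swizzled scale-inverse buffers above, where applicable) must stay alive until the
+ // call returns.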
NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), - te_bias_vector.data(), te_pre_gelu_out_vector.data(), - te_A_vector.size(), transa, transb, grad, - te_workspace_vector.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(), + transa, transb, grad, te_workspace_vector.data(), accumulate, + use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); }); return bias; } diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp index 21d3e0574..acf04900e 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp @@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction, - weight_decay, device_id, at::cuda::getCurrentCUDAStream()); + weight_decay, at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, @@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_param_remainder_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, - beta2, epsilon, step, mode, bias_correction, weight_decay, device_id, - at::cuda::getCurrentCUDAStream()); + beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, @@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), - device_id, at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, @@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, auto lr_cu = makeTransformerEngineTensor(lr); auto step_cu = makeTransformerEngineTensor(step); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_capturable_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, 
bias_correction, weight_decay, - inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream()); + inv_scale_cu.data(), at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag, @@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl auto lr_cu = makeTransformerEngineTensor(lr); auto step_cu = makeTransformerEngineTensor(step); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_capturable_master_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay, - inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream()); + inv_scale_cu.data(), at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp index 290f70b57..8a1a34698 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp @@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda( auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_compute_scale_and_scale_inv_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8, - force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream()); + force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp index 1e8eb44d9..d33a2520e 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp @@ -43,12 +43,11 @@ std::tuple multi_tensor_l2norm_cuda( auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor); auto ret_cu = makeTransformerEngineTensor(ret); auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(), per_tensor, - max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream()); + max_chunks_per_tensor, at::cuda::getCurrentCUDAStream()); return std::tuple(ret, ret_per_tensor); } @@ -91,13 +90,11 @@ std::tuple multi_tensor_unscale_l2norm_cuda( auto ret_cu = makeTransformerEngineTensor(ret); auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_unscale_l2norm_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(), - inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id, - 
at::cuda::getCurrentCUDAStream()); + inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, at::cuda::getCurrentCUDAStream()); return std::tuple(ret, ret_per_tensor); } diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp index ba33f04bf..2db936f84 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp @@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, - num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream()); + num_tensors, scale, at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp index de3209535..2c6a6b7c4 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp @@ -15,11 +15,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, wd, momentum, dampening, lr, nesterov, first_run, - wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream()); + wd_after_momentum, scale, at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index f834f94af..728d39cbd 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -112,8 +112,15 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + std::tie(unquantized_out_cu, unquantized_out) = + my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); + } else { + NoneQuantizer q{none}; + std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + } } TensorWrapper &kernel_out_cu = force_unfused_kernel ? 
unquantized_out_cu : out_cu; @@ -141,56 +148,13 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { - QuantizationConfigWrapper quant_config; - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); -#ifdef __HIP_PLATFORM_AMD__ - at::Tensor ws = allocate_amax_workspace(unquantized_out_cu); - TensorWrapper tw = makeTransformerEngineTensor(ws); -#endif - - NVTE_SCOPED_GIL_RELEASE({ -#ifdef __HIP_PLATFORM_AMD__ - nvte_compute_amax_with_workspace(unquantized_out_cu.data(), out_cu.data(), - tw.data(), - at::cuda::getCurrentCUDAStream()); -#else - nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), - at::cuda::getCurrentCUDAStream()); -#endif - }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = - my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); - } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); + } else { + my_quantizer->quantize(unquantized_out_cu, out_cu); } - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); } return {out, py::cast(mu), py::cast(rsigma)}; @@ -239,6 +203,52 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, return {py::cast(dx), py::cast(dgamma)}; } +std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &add, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma) { + const auto &dz_ = dz.contiguous(); + const auto &x_ = x.contiguous(); + const auto &add_ = add.contiguous(); + const auto &rsigma_ = rsigma.contiguous(); + const auto &gamma_ = gamma.contiguous(); + + auto dx = at::empty_like(x_); + auto dgamma = at::empty_like(gamma_); + TensorWrapper workspace; + + auto dz_cu = makeTransformerEngineTensor(dz_); + auto x_cu = makeTransformerEngineTensor(x_); + auto add_cu = 
makeTransformerEngineTensor(add_); + auto rsigma_cu = makeTransformerEngineTensor(rsigma_); + auto gamma_cu = makeTransformerEngineTensor(gamma_); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + + // This call populates tensors with the required config. + NVTE_SCOPED_GIL_RELEASE({ + nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), + gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + }); + + // Alloc space for Tensors. + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // Actual call to bwd kernel. + NVTE_SCOPED_GIL_RELEASE({ + nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), + gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + }); + + return {py::cast(dx), py::cast(dgamma)}; +} + std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object out, py::handle quantizer, DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { @@ -284,8 +294,15 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + std::tie(unquantized_out_cu, unquantized_out) = + my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); + } else { + NoneQuantizer q{none}; + std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + } } TensorWrapper &kernel_out_cu = force_unfused_kernel ? 
unquantized_out_cu : out_cu; @@ -313,56 +330,13 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { - QuantizationConfigWrapper quant_config; - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); -#ifdef __HIP_PLATFORM_AMD__ - at::Tensor ws = allocate_amax_workspace(unquantized_out_cu); - TensorWrapper tw = makeTransformerEngineTensor(ws); -#endif - - NVTE_SCOPED_GIL_RELEASE({ -#ifdef __HIP_PLATFORM_AMD__ - nvte_compute_amax_with_workspace(unquantized_out_cu.data(), out_cu.data(), - tw.data(), - at::cuda::getCurrentCUDAStream()); -#else - nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), - at::cuda::getCurrentCUDAStream()); -#endif - }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = - my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); - } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); + } else { + my_quantizer->quantize(unquantized_out_cu, out_cu); } - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); } return {out, py::none(), py::cast(rsigma)}; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 14f5c83a4..55b1d179e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
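Note: the fused RMSNorm-backward-plus-residual-add kernel added above is exposed later in this diff as ``rmsnorm_bwd_add`` with positional arguments. A rough sketch of a call from Python is shown below; the shapes, dtypes and the zero ``sm_margin`` are illustrative assumptions, and ``transformer_engine_torch`` is simply the compiled extension module already used elsewhere in the repo, not a wrapper introduced by this change.

.. code-block:: python

    import torch
    import transformer_engine_torch as tex

    # Illustrative inputs: dz is the incoming gradient, x the forward input,
    # add an extra gradient term fused into dx, gamma the RMSNorm weight, and
    # rsigma a stand-in for the reciprocal-RMS statistic saved by rmsnorm_fwd.
    x = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)
    dz = torch.randn_like(x)
    add = torch.randn_like(x)
    gamma = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
    rsigma = torch.rand(128, device="cuda", dtype=torch.float32)

    # Positional arguments mirror the C++ signature:
    # (dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma)
    dx, dgamma = tex.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, 0, False)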
@@ -113,39 +113,55 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, - py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); + py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false, + py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); + /* GELU and variants*/ m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), py::arg("quantizer")); - m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), - py::arg("quantizer")); m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), py::arg("quantizer")); + m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + py::arg("quantizer")); m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), py::arg("quantizer")); + /* ReLU and variants */ + m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), + py::arg("quantizer")); m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), py::arg("quantizer")); - m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), + m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), py::arg("quantizer")); - m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"), py::arg("quantizer")); - m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), + /* SwiGLU and variants */ + m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); + /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + /* Backward of ReLU and variants */ + m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), + m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + m.def("dsreglu", 
transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + /* Backward of SiLU and variants */ + m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), + m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", @@ -203,6 +219,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); + m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, + "Fused backward of RMSNorm + add"); m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, @@ -213,6 +231,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, + "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), + py::call_guard()); m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); m.def("compute_amax", &transformer_engine::pytorch::compute_amax, @@ -259,6 +280,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Apply RoPE FWD", py::call_guard()); m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward, "Fused Apply RoPE BWD", py::call_guard()); + m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward, + "Fused Apply QKV RoPE FWD", py::call_guard()); + m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, + "Fused Apply QKV RoPE BWD", py::call_guard()); // fused router m.def("fused_topk_with_score_function_fwd", @@ -286,6 +311,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); + // Dropout + m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG", + py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt); + m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG", + py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"), + py::arg("grad_input") = std::nullopt); + // Misc #ifndef USE_ROCM m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version, @@ -396,6 +428,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda, "Fused compute scale and 
scale_inv from amax", py::call_guard()); +#ifndef USE_ROCM + // Comm+GEMM Overlap + m.def("bulk_overlap_ag_with_external_gemm", + &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm, + "Bulk overlap All-Gather with a GEMM operation launched by another communicator", + py::call_guard(), py::arg("allgather_communicator"), + py::arg("send_stream"), py::arg("recv_stream")); +#else + m.def("bulk_overlap_ag_with_external_gemm", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); +#endif + // Data structures py::class_(m, "FP8TensorMeta") .def(py::init<>()) diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 0ede63d77..10a889b56 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -31,7 +31,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { at::Tensor ws = allocate_amax_workspace(te_input); TensorWrapper tw = makeTransformerEngineTensor(ws); nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), - tw.data(), + tw.data(), nullptr, at::cuda::getCurrentCUDAStream()); #else nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index d2f7107fe..7dfdf9954 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -18,31 +18,64 @@ namespace pytorch { at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { init_extension(); - const auto dim = input.dim(); - NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); - - if (input.dim() > 2) { - input = input.view({-1, input.size(dim - 1)}); + // Tensor dimensions + const auto shape = getTensorShape(input); + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } } + const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1; + const size_t N = shape.size() > 0 ? 
shape.back() : 1; - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - + // Output tensor at::Tensor out; if (output.has_value()) { out = *output; } else { - out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + out = at::empty(transpose_shape_int64, opts); + } + + // Return immediately if tensor is empty + if (M == 0 || N == 0) { + return out; } - if (M == 0 || N == 0) return out; + // Compute transpose auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector{M, N}, otype); auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return out; } +at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { + init_extension(); + + // Make sure input is contiguous + const auto &input = tensor.contiguous(); + + // Allocate output tensor if needed + if (!out) { + auto in_shape = getTensorShape(input); + NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); + std::vector out_shape_int64(in_shape.begin(), in_shape.end()); + out_shape_int64[0] = static_cast(in_shape[1]); + out_shape_int64[1] = static_cast(in_shape[0]); + auto opts = at::TensorOptions().dtype(input.dtype()).device(input.device()); + out = at::empty(out_shape_int64, opts); + } + + // Launch kernel + const TensorWrapper te_input = makeTransformerEngineTensor(input); + TensorWrapper te_output = makeTransformerEngineTensor(*out); + nvte_swap_first_dims(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + + return std::move(*out); +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0ce1fc90e..37c13362c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -12,6 +14,27 @@ namespace transformer_engine::pytorch { +namespace { + +/*! @brief Transposed tensor shape + * + * The tensor is interpreted as a 2D matrix by flattening all but the + * last dimension, and then transposed. 
+ */ +template +std::vector make_transpose_shape(const std::vector& shape) { + std::vector ret; + if (shape.size() > 0) { + ret.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + } + return ret; +} + +} // namespace + constexpr size_t MXFP8_BLOCK_SIZE = 32; Quantizer::Quantizer(const py::handle& quantizer) { @@ -37,24 +60,36 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { - at::TensorOptions opts; - opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); - std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - } - at::Tensor ret; - if (rowwise_data.has_value()) { - ret = std::move(*rowwise_data); - } else { - ret = at::empty(torch_shape, opts); - } +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype) const { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + return create_tensor(shape, dtype, at::empty(shape_int64, opts)); +} + +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype, + at::Tensor data) const { + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), py::cast(data)}; +} + +std::pair NoneQuantizer::convert_and_update_tensor( + py::object tensor) const { + auto tensor_pyt = tensor.cast(); + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(tensor_pyt.data_ptr(), + GetTransformerEngineDType(tensor_pyt.scalar_type()), + getTensorShape(tensor_pyt)); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), std::move(tensor)}; +} - TensorWrapper tensor; - tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); - return {std::move(tensor), py::cast(ret)}; +void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + NVTE_ERROR("NoneQuantizer does not support quantization"); } void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -63,81 +98,183 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); +} - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype) const { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + at::Tensor scale_inv = at::empty(std::vector{1}, opts); + return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, std::optional data, + std::optional transpose, std::optional scale_inv) const { using 
namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); + // Initialize data tensor + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data && !data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data = at::empty(shape_int64, opts); + } else if (!with_data && data) { + data.reset(); } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(rowwise_torch_shape, opts); - } + py::object data_py = with_data ? py::cast(*data) : py::none(); + + // Initialize transpose tensor + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose && !transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose = at::empty(transpose_shape, opts); + } else if (!with_transpose && transpose) { + transpose.reset(); } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); + + // Initialize scale-inverse tensor + if (!scale_inv) { + scale_inv = at::reciprocal(scale); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - opts = opts.dtype(torch::kFloat32); - // TODO: Replace with an empty tensor. 
- at::Tensor scale_inv = at::reciprocal(scale); - py::object ret; + + // Construct Python FP8 tensor + py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); + if (with_data) { + out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); + } + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair Float8Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { + data_tensor = data_py.cast(); + } + if (has_transpose) { + transpose_tensor = transpose_py.cast(); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + shape.push_back(transpose_shape.front()); } - 
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + if (has_data) { + auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); + } + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); } - this->set_quantization_params(&tensor); - return {std::move(tensor), std::move(ret)}; + + // Coerce data tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && need_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); } Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) @@ -173,85 +310,235 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - // quantize output and its transpose - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair 
Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); + // Initialize data tensor + at::Tensor data_tensor; + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); - } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(rowwise_torch_shape, opts); - } - } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + + // Initialize transpose tensor + at::Tensor transpose_tensor; + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. - at::Tensor scale_inv = at::reciprocal(scale); + // Initialize scale-inverse tensor + at::Tensor scale_inv_tensor; + { + const std::vector scale_inv_shape = {1}; + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + scale_inv_tensor = at::empty(scale_inv_shape, opts); + } - py::object ret; + // Construct Python FP8 tensor + py::object out_py; + py::object data_py = with_data ? py::cast(data_tensor) : py::none(); + py::object transpose_py = with_transpose ? 
py::cast(transpose_tensor) : py::none(); if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); + if (with_data) { + out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair Float8CurrentScalingQuantizer::create_hp_tensor_with_amax( + const std::vector& shape, DType dtype) { + amax.zero_(); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair Float8CurrentScalingQuantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), + "Float8CurrentScalingQuantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { + data_tensor = data_py.cast(); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + if (has_transpose) { + transpose_tensor = transpose_py.cast(); + } + at::Tensor scale_inv_tensor = 
tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + shape.push_back(transpose_shape.front()); + } + if (has_data) { + auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); } - this->set_quantization_params(&tensor); - return {std::move(tensor), std::move(ret)}; + // Coerce data tensor in Python tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && need_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, + bool compute_amax) { + auto stream = at::cuda::getCurrentCUDAStream(); + + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + // Quantization configs + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + + // Compute amax + if (compute_amax) { +#ifdef __HIP_PLATFORM_AMD__ + at::Tensor ws = allocate_amax_workspace(input); + TensorWrapper tw = makeTransformerEngineTensor(ws); + NVTE_SCOPED_GIL_RELEASE({ + nvte_compute_amax_with_workspace(input.data(), out.data(), tw.data(), quant_config, stream); + }); +#else + NVTE_SCOPED_GIL_RELEASE( + { 
nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); +#endif + } + + // Perform amax reduction if needed + if (with_amax_reduction) { + // allreduce amax tensor + c10d::AllreduceOptions opts; + opts.reduceOp = c10d::ReduceOp::MAX; + std::vector tensors = {amax}; + NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); }); + } + + // Compute scaling factor + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); }); + + // Cast to FP8 + out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); +} + +void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + this->quantize_impl(input, out, noop_flag, true); +} + +void Float8CurrentScalingQuantizer::quantize_with_amax( + TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { + NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(), + "Input does not use the appropriate amax tensor"); + input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + this->quantize_impl(input, out, noop_flag, false); } Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -264,23 +551,10 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } -void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { - // Change the rowwise and columnwise_data to the configured dtype. - // May be a switch between E5M2 and E4M3. - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -299,11 +573,7 @@ std::pair Float8BlockQuantizer::create_tensor( : Float8BlockScaleTensorFormat::GEMM_READY); if (rowwise_usage) { - if (rowwise_data.has_value()) { - data_rowwise = std::move(*rowwise_data); - } else { - data_rowwise = at::empty(torch_shape, opts); - } + data_rowwise = at::empty(torch_shape, opts); auto scale_shape = get_scale_shape(shape, false); size_t sinv0 = scale_shape[0]; size_t sinv1 = scale_shape[1]; @@ -373,6 +643,177 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::pair Float8BlockQuantizer::convert_and_update_tensor( + py::object tensor) const { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto 
rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data."); + + // Tensor options and dimensions + at::TensorOptions opts; + at::TensorOptions scale_opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + + auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector { + if (!columnwise_data) { + return std::vector(); + } + if (all_gather_usage) { + return getTensorShape(*columnwise_data); + } + std::vector shape = getTensorShape(*columnwise_data); + std::vector shape_transposed(shape.size()); + for (size_t i = 0; i + 1 < shape.size(); ++i) { + shape_transposed[i] = shape[i + 1]; + } + if (shape.size() > 0) { + shape_transposed[shape.size() - 1] = shape[0]; + } + return shape_transposed; + }; + std::vector shape; + if (rowwise_data) { + shape = getTensorShape(*rowwise_data); + if (columnwise_data) { + auto expected_shape = get_columnwise_shape(all_gather_usage); + NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, + ") and column-wise data (shape=", expected_shape, ") do not match"); + } + } else { + shape = get_columnwise_shape(all_gather_usage); + } + std::vector torch_shape; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + } + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + rowwise_data = at::empty(torch_shape, opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + auto scale_shape = get_scale_shape(shape, false); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; + rowwise_scale_inv = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + std::vector columnwise_shape; + std::vector torch_columnwise_shape; + if (torch_shape.size() > 0) { + if (!all_gather_usage) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } else { + // assert we are doing 1D scaling + NVTE_CHECK(block_scaling_dim == 1, + "Compact columnwise format is not supported for 128x128 2D block scaling."); + torch_columnwise_shape = torch_shape; + columnwise_shape = shape; + } + } + if (!columnwise_data) { + columnwise_data = at::empty(torch_columnwise_shape, opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + auto scale_shape = get_scale_shape(shape, true); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; + columnwise_scale_inv = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + } else { // columnwise_usage == false + if (columnwise_data) { + 
columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + } + + auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); + + if (rowwise_usage) { + const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto& rowwise_shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + if (columnwise_usage) { + const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto& shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + set_quantization_params(&ret); + return {std::move(ret), std::move(tensor)}; +} + +void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + if (all_gather_usage) { + quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); +} + std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; @@ -452,84 +893,206 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize this->dtype = quantizer.attr("dtype").cast(); } -void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair MXFP8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { +std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { using namespace pybind11::literals; - std::vector torch_shape; - size_t numel = 1; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - numel *= s; - } - - TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); - at::TensorOptions opts; - at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, - columnwise_scale_inv; // TODO(pgadzinski) - change - opts = 
opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(torch_shape, opts); + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; } - auto scale_shape = get_scale_shape(shape, false); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - rowwise_scale_inv = at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv( - rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); } + const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", shape, ")"); + const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); + const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor; + const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + if (rowwise_usage) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), + rowwise_scale_inv_shape.end()); + rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); + } if (columnwise_usage) { - auto scale_shape = get_scale_shape(shape, true); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - columnwise_data = at::empty(torch_shape, opts); - columnwise_scale_inv = - at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); - - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); - tensor.set_columnwise_scale_inv( - columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), + columnwise_scale_inv_shape.end()); + columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } - this->set_quantization_params(&tensor); - py::object ret; + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + return need_cast ? 
py::cast(tensor) : py::none(); + }; + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + + // Construct Python MXFP8 tensor + py::object out_py; if (internal) { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); - ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } - return {std::move(tensor), std::move(ret)}; + // Construct C++ MXFP8 tensor + TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + rowwise_scale_inv_shape); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + columnwise_scale_inv_shape); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair MXFP8Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); + + // Tensor dimensions + std::vector shape; + if (columnwise_data) { + shape = getTensorShape(*columnwise_data); + if (rowwise_data) { + auto expected_shape = getTensorShape(*rowwise_data); + NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); + } + } else { // Already checked 
columnwise_data_tensor == true + shape = getTensorShape(*rowwise_data); + } + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_data = at::empty(shape_int64, opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, false); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + if (!columnwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_data = at::empty(shape_int64, opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, true); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + } + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ MXFP8 tensor + TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*columnwise_scale_inv)); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); } std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index b58573419..44b636930 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -80,4 +80,99 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap return swizzled_scale_inv; } 
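The MXFP8 path above treats every tensor as a 2D matrix (all leading dimensions flattened into the first) and stores one FP8E8M0 scale per 32-element block: along the last dimension for row-wise usage and along the first dimension for column-wise usage. A minimal Python sketch of that shape bookkeeping, assuming the 32-element block size and ignoring any alignment padding the real ``get_scale_shape`` may apply:

.. code-block:: python

    # Sketch only: mirrors the divisibility check and the unpadded
    # scale-inverse shapes used by the MXFP8 quantizer code above.
    MXFP8_BLOCK_SIZE = 32  # assumption: 32-element MXFP8 blocks

    def mxfp8_scale_inv_shape(shape, columnwise=False):
        flat_first = 1
        for d in shape[:-1]:
            flat_first *= d
        flat_last = shape[-1] if shape else 1
        if flat_first % MXFP8_BLOCK_SIZE or flat_last % MXFP8_BLOCK_SIZE:
            raise ValueError(
                f"MXFP8 requires dims divisible by {MXFP8_BLOCK_SIZE} (got shape={shape})"
            )
        if columnwise:
            # One scale per 32 consecutive elements along the first dim.
            return (flat_first // MXFP8_BLOCK_SIZE, flat_last)
        # One scale per 32 consecutive elements along the last dim.
        return (flat_first, flat_last // MXFP8_BLOCK_SIZE)

    print(mxfp8_scale_inv_shape((2, 1024, 4096)))                   # (2048, 128)
    print(mxfp8_scale_inv_shape((2, 1024, 4096), columnwise=True))  # (64, 4096)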
+std::optional multi_tensor_swizzle_scaling_factors( + std::vector& tensors, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (tensors.empty()) { + return std::nullopt; + } + + bool all_same_scaling_mode = std::all_of( + tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) { + return val.scaling_mode() == tensors.front().scaling_mode(); + }); + NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same."); + + if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { + return std::nullopt; + } + + std::vector wrappers; + std::vector input_tensors, output_tensors; + + // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs + std::vector> scale_inv_shapes; + std::vector scale_inv_dptrs; + size_t buffer_size = 0; + std::vector scale_inv_offsets; + constexpr size_t scale_elem_size = 1; + for (auto& tensor : tensors) { + NVTEBasicTensor scale_inv; + if (rowwise) { + scale_inv = tensor.get_rowwise_scale_inv(); + } else { + scale_inv = tensor.get_columnwise_scale_inv(); + } + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_inv_offsets.push_back(buffer_size); + buffer_size += product(scale_inv_shape) * scale_elem_size; + scale_inv_shapes.emplace_back(scale_inv_shape); + scale_inv_dptrs.push_back(scale_inv.data_ptr); + } + + // Allocate full buffer + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + + for (size_t i = 0; i < tensors.size(); ++i) { + auto& tensor = tensors[i]; + void* scale_inv_dptr = scale_inv_dptrs[i]; + void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); + auto input_shape = nvte_shape_to_vector(tensor.shape()); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. + transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. + tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + } else { + input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_columnwise_data(tensor.columnwise_dptr(), + transformer_engine::DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv( + swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. 
+ tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + } + + input_tensors.emplace_back(input_cu.data()); + output_tensors.emplace_back(output_cu.data()); + wrappers.emplace_back(std::move(input_cu)); + wrappers.emplace_back(std::move(output_cu)); + } + + // Launch kernel + nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(), + input_tensors.size(), at::cuda::getCurrentCUDAStream()); + + return buffer; +} + #endif //!USE_ROCM diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 97f22ae18..621cc1db8 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -17,12 +17,19 @@ #include "transformer_engine/transformer_engine.h" -/* Swizzle the scaling factor of the input tensor. +/*! \brief Swizzle the scaling factor of the input tensor. * * The returned swizzled scaling factor tensor should be kept alive during the GEMM. */ std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, - bool trans); + bool rowwise); + +/*! \brief Swizzle the scaling factor of the input tensors. + * + * The returned swizzled scaling factor tensors should be kept alive during the GEMMs. + */ +std::optional multi_tensor_swizzle_scaling_factors( + std::vector &inputs, bool rowwise); #endif //!USE_ROCM diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c3ec514a5..e809528da 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -985,6 +985,15 @@ def _all_gather_fp8( return out, handle +def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]: + """Get quantizer format.""" + if isinstance(quantizer, DebugQuantizer): + quantizer = quantizer.parent_quantizer + if isinstance(quantizer, Float8BlockQuantizer): + return quantizer.all_gather_usage + return None + + def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: """Make quantizer compact""" _quantizer = quantizer @@ -1133,6 +1142,10 @@ def _all_gather_fp8_blockwise( "Dequantizing and requantizing to Float8BlockwiseQTensor." ) inp = quantizer(inp.dequantize()) + + # Construct Float8BlockwiseQTensor output tensor + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + quantizer.all_gather_usage = orig_all_gather_usage # Begin to do network communication, need to make sure compact format @@ -1142,9 +1155,6 @@ def _all_gather_fp8_blockwise( f"but found data_format={inp._data_format}" ) - # Construct Float8BlockwiseQTensor output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - # Coalesce NCCL collectives with torch.distributed._coalescing_manager( group=process_group, @@ -1220,14 +1230,12 @@ def _all_gather_mxfp8( if inp._rowwise_data is not None: in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device - dtype = inp._rowwise_data.dtype elif inp._columnwise_data is not None: in_shape = inp._columnwise_data.size() device = inp._columnwise_data.device - dtype = inp._columnwise_data.dtype else: raise ValueError("Got MXFP8 input tensor without any data") - dtype = torch.bfloat16 + dtype = torch.bfloat16 # Guess high-precision dtype. 
else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " @@ -1347,6 +1355,44 @@ def gather_along_first_dim( inp = quantizer(inp) return inp, None + # Debug case - call gather_along_first_dim on each tensor + if isinstance(inp, DebugQuantizedTensor): + out_obj = DebugQuantizedTensor( + rowwise_gemm_tensor=inp.rowwise_gemm_tensor, + columnwise_gemm_tensor=inp.columnwise_gemm_tensor, + quantizer=inp.quantizer, + layer_name=inp._layer_name, + tensor_name=inp._tensor_name, + ) + rowwise = inp.get_tensor(False) + columnwise = inp.get_tensor(True) + # shapes + final_quantizer = ( + None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer + ) + rowwise_total = None + if rowwise is not None: + rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[ + 0 + ] + out_obj.rowwise_gemm_tensor = rowwise_total + if rowwise is not columnwise: + final_quantizer_columnwise = ( + None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer + ) + columnwise_total = None + if columnwise is not None: + columnwise_total, _ = gather_along_first_dim( + columnwise, process_group, False, final_quantizer_columnwise + ) + out_obj.columnwise_gemm_tensor = columnwise_total + else: + # Sometimes the same object is used both for rowwise and columnwise gemms, + # and we want to avoid double all-gathers. + out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor + + return out_obj, None + # Output tensor dims out_shape = list(inp.size()) out_shape[0] *= world_size @@ -1384,34 +1430,6 @@ def gather_along_first_dim( out_shape=out_shape, ) - # Debug case - call gather_along_first_dim on each tensor - if isinstance(inp, DebugQuantizedTensor): - out_obj = inp - rowwise = inp.get_tensor(False) - columnwise = inp.get_tensor(True) - final_quantizer = ( - None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer - ) - # Temporary fix for TP communication of Float8BlockwiseQTensorBase - if isinstance(rowwise, Float8BlockwiseQTensorBase): - rowwise = inp._original_tensor - rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] - out_obj.rowwise_gemm_tensor = rowwise_total - if rowwise is not columnwise: - final_quantizer_columnwise = ( - None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer - ) - # Temporary fix for TP communication of Float8BlockwiseQTensorBase - if isinstance(columnwise, Float8BlockwiseQTensorBase): - columnwise = inp._original_tensor - columnwise_total, _ = gather_along_first_dim( - columnwise, process_group, False, final_quantizer_columnwise - ) - out_obj.columnwise_gemm_tensor = columnwise_total - else: - out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor - return out_obj, None - # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( @@ -1422,6 +1440,7 @@ def gather_along_first_dim( inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format + compact = _get_quantizer_format(quantizer) _set_quantizer_format(quantizer, compact=False) out = torch.empty( out_shape, @@ -1431,6 +1450,7 @@ def gather_along_first_dim( ) torch.distributed.all_gather_into_tensor(out, inp, group=process_group) out = quantizer(out) + _set_quantizer_format(quantizer, compact=compact) return out, None # Dequantize quantized tensor if not supported diff --git a/transformer_engine/pytorch/fp8.py 
b/transformer_engine/pytorch/fp8.py index 55280aa82..15cb88b00 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -80,6 +80,19 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." +def check_recipe_support(recipe: Recipe) -> None: + """Check if the given recipe is supported.""" + recipe_supported = True + unsupported_reason = "" + if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): + recipe_supported, unsupported_reason = check_fp8_support() + elif isinstance(recipe, Float8BlockScaling): + recipe_supported, unsupported_reason = check_fp8_block_scaling_support() + elif isinstance(recipe, MXFP8BlockScaling): + recipe_supported, unsupported_reason = check_mxfp8_support() + assert recipe_supported, unsupported_reason + + def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if IS_HIP_EXTENSION: @@ -90,11 +103,10 @@ def get_default_fp8_recipe() -> Recipe: return MXFP8BlockScaling() return DelayedScaling() if check_mxfp8_support()[0]: - # This is a temporary restriction until MXFP8 is supported for all - # gemm layouts. - if get_device_compute_capability() >= (12, 0): - return Float8BlockScaling() return MXFP8BlockScaling() + if get_device_compute_capability() >= (12, 0): + # This is a temporary restriction until MXFP8 is supported for all gemm layouts. + return Float8CurrentScaling() return DelayedScaling() @@ -672,6 +684,8 @@ def fp8_autocast( distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step. """ + if enabled: + check_recipe_support(fp8_recipe) fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() FP8GlobalStateManager.fp8_autocast_enter( enabled=enabled, diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 0c7e3fe19..4984013ab 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -6,6 +6,8 @@ """Functions for CUDA Graphs support in FP8""" from collections.abc import Iterable +import contextlib +import gc from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -23,6 +25,8 @@ from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule from .ops.op import BasicOperation +from .ops import Sequential +from .ops.fuser import OperationFuser from .utils import make_weak_ref __all__ = ["make_graphed_callables"] @@ -46,7 +50,7 @@ def set_capture_end() -> None: _IS_GRAPH_CAPTURING = False -def is_graph_capturing() -> None: +def is_graph_capturing() -> bool: """Return whether within `make_graphed_callables`.""" return _IS_GRAPH_CAPTURING @@ -58,6 +62,25 @@ def graph_pool_handle(): return _graph_pool_handle() +@contextlib.contextmanager +def _graph_context_wrapper(*args, **kwargs): + """Wrapper around `torch.cuda.graph`. + + This wrapper is a temporary workaround for a PyTorch bug: + automatic garbage collection can destroy a graph while another + graph is being captured, resulting in a CUDA error. See + https://github.com/pytorch/pytorch/pull/161037. 
+ + """ + gc_is_enabled = gc.isenabled() + if gc_is_enabled: + gc.disable() + with torch.cuda.graph(*args, **kwargs): + yield + if gc_is_enabled: + gc.enable() + + def _make_graphed_callables( callables: SingleOrTuple[Callable], sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], @@ -179,24 +202,17 @@ def _make_graphed_callables( assert isinstance( sample_args, list ), "sample_args must be a list for _reuse_graph_input_output_buffers." - len_args = len(sample_args[0]) - for i, arg in enumerate(sample_args): - assert len_args == len( - arg - ), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`." - len_kwargs = len(sample_kwargs[0]) - assert isinstance( - sample_kwargs, list - ), "sample_kwargs must be a list for _reuse_graph_input_output_buffers." - for i, kwarg in enumerate(sample_kwargs): - assert len_kwargs == len(kwarg), ( - "Keyword arguments must have same length and shape for" - " `_reuse_graph_input_output_buffers`." - ) # Reorganize args and kwargs for input tensor reuse. + # fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples. + # Each tuple contains the sample key signature and its fwd_idx. When we finish a backward + # chunk, we pop the corresponding fwd_idx and push to the consumed_sample_q. + # consumed_sample_q is keyed by the sample key signature. The value is a queue of the + # fwd_idx whose backward has been called so that we can reuse the same static buffers. + # In this way, we can reuse the same static input buffers for the non-overlapping samples + # with the same input signature. fwd_sample_qs = {} - consumed_sample_q = [] + consumed_sample_q = {} fwd_idx = [0] * num_model_chunks for c_id in _order: m_chunk = abs(c_id) - 1 @@ -208,10 +224,21 @@ def _make_graphed_callables( fwd_sample_idx = [ sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk]) ] - fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx + if m_chunk not in fwd_sample_qs: + fwd_sample_qs[m_chunk] = [] for per_callable_fwd_idx in fwd_sample_idx: - if consumed_sample_q: - reuse_fwd_idx = consumed_sample_q.pop(0) + sample_args_keys = tuple( + (t.shape, t.dtype, t.layout) for t in sample_args[per_callable_fwd_idx] + ) + sample_kwargs_keys = tuple( + (k, v.shape, v.dtype, v.layout) + for k, v in sorted(sample_kwargs[per_callable_fwd_idx].items()) + ) + sample_keys = sample_args_keys + sample_kwargs_keys + + fwd_sample_qs[m_chunk].append((sample_keys, per_callable_fwd_idx)) + if consumed_sample_q.get(sample_keys, []): + reuse_fwd_idx = consumed_sample_q[sample_keys].pop(0) sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx] fwd_idx[m_chunk] += 1 @@ -219,7 +246,12 @@ def _make_graphed_callables( num_consumed_samples = min( len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk] ) - consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples] + for sample_keys, per_callable_fwd_idx in fwd_sample_qs[m_chunk][ + :num_consumed_samples + ]: + if sample_keys not in consumed_sample_q: + consumed_sample_q[sample_keys] = [] + consumed_sample_q[sample_keys].append(per_callable_fwd_idx) fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] if fp8_weight_caching: @@ -340,6 +372,16 @@ def _make_graphed_callables( def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if isinstance(module, TransformerEngineBaseModule): visited_te_modules.add(module) + # If forward is called on a BasicOperation directly the hook 
will run + elif isinstance(module, BasicOperation): + visited_te_modules.add(module) + # If forward is called on a te.ops.Sequential it is not called on its constituent ops + elif isinstance(module, Sequential): + assert module._module_groups is not None, "Should have been initialized by warmup" + for module_group in module._module_groups: + if isinstance(module_group, OperationFuser): + for basic_op in module_group._basic_ops: + visited_te_modules.add(basic_op) # Run warmup and do the above filtering. with torch.cuda.stream(torch.cuda.Stream()): @@ -412,8 +454,8 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_static_grad_inputs = [None] * len(flatten_sample_args) fwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks - static_grad_outputs = None - previous_per_callable_bwd_idx = None + static_grad_outputs_dict = {} + previous_chunk_last_callable_bwd_idx = None for c_id in _order: if c_id > 0: # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] @@ -426,7 +468,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[per_callable_fwd_idx] kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] - with torch.cuda.graph(fwd_graph, pool=mempool): + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) @@ -436,6 +478,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument else: # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] m_chunk = -c_id - 1 + previous_per_callable_bwd_idx = None for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no @@ -444,14 +487,26 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_outputs = per_callable_static_outputs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx] # For now, assumes all static_outputs require grad - if not _reuse_graph_input_output_buffers or static_grad_outputs is None: + if _reuse_graph_input_output_buffers: # Note for _reuse_graph_input_output_buffers: grad output is only used # within backward, so we can reuse the same static buffers every time. + static_grad_outputs_keys = tuple( + (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad + ) + if static_grad_outputs_keys in static_grad_outputs_dict: + static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys] + else: + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None + for o in static_outputs + ) + static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs + else: static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with torch.cuda.graph(bwd_graph, pool=mempool): + with _graph_context_wrapper(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -486,19 +541,29 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref( static_outputs ) - # Weak ref the static grad inputs of the previous backward pass. 
- # Note: After a backward pass, we assume Mcore will send the - # grad input to another pipeline parallel rank and that the - # communication is finished before the end of the next backward - # pass. + + # Weak ref the static grad inputs of the previous backward pass within the + # same chunk. if previous_per_callable_bwd_idx is not None: - per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = ( - make_weak_ref( - per_callable_static_grad_inputs[previous_per_callable_bwd_idx] - ) + idx = previous_per_callable_bwd_idx + per_callable_static_grad_inputs[idx] = make_weak_ref( + per_callable_static_grad_inputs[idx] ) previous_per_callable_bwd_idx = per_callable_bwd_idx + # Weak ref the static grad inputs of the previous chunk's last backward + # pass. + # Note: After a chunk's backward pass, we assume Mcore will send the grad + # input to another pipeline parallel rank and that the communication is + # finished before the end of the next chunk's backward pass. + if l_no == 0: + if previous_chunk_last_callable_bwd_idx is not None: + idx = previous_chunk_last_callable_bwd_idx + per_callable_static_grad_inputs[idx] = make_weak_ref( + per_callable_static_grad_inputs[idx] + ) + previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx + bwd_idx[m_chunk] += 1 else: # Capture forward graphs @@ -506,7 +571,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_output_unflatten_spec = [] graph_id = 0 for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs): - with torch.cuda.graph(fwd_graph, pool=mempool): + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) graph_callables[graph_id] = func graph_id += 1 @@ -528,7 +593,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with torch.cuda.graph(bwd_graph, pool=mempool): + with _graph_context_wrapper(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -678,31 +743,41 @@ def new_fwd(*user_args, **user_kwargs): # run the graph, otherwise run the original forward method if func.training == graph_training_state: # Set the FP8 group from global amax reduction. 
- for m in func.modules(): - if ( - isinstance(m, TransformerEngineBaseModule) - and FP8GlobalStateManager.is_fp8_enabled() - ): + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + for m in func.modules(): if m not in visited_te_modules: # Only Set the FP8 meta for the modules included by forward continue - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - from transformer_engine.pytorch.attention.dot_product_attention import ( - DotProductAttention, - ) - - if ( - isinstance(m, DotProductAttention) - and not fp8_recipe.fp8_mha - and not fp8_recipe.fp8_dpa - ): - # Don't need to update FP8 meta for non-FP8 DPA - continue - m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, - ) + if isinstance(m, TransformerEngineBaseModule): + from transformer_engine.pytorch.attention.dot_product_attention import ( + DotProductAttention, + ) + + if ( + isinstance(m, DotProductAttention) + and not fp8_recipe.fp8_mha + and not fp8_recipe.fp8_dpa + ): + # Don't need to update FP8 meta for non-FP8 DPA + continue + m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + m.fp8_meta, + ) + elif isinstance(m, BasicOperation): + for mode in ("forward", "backward"): + if m.num_quantizers(mode): + m._fp8_metas[mode][ + "fp8_group" + ] = FP8GlobalStateManager.get_fp8_group() + m._fp8_metas[mode][ + "recipe" + ] = FP8GlobalStateManager.get_fp8_recipe() + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + m._fp8_metas[mode], + ) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) @@ -725,7 +800,7 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_recipe: Recipe, + fp8_recipe: Optional[Recipe], ) -> Optional[List[Any]]: """ Returns the FP8 tensors for all modules @@ -744,7 +819,7 @@ def save_fp8_tensors( m.adjust_amax_history_length(fp8_recipe.amax_history_len) module_tensors = m.get_fp8_meta_tensors() elif isinstance(m, BasicOperation): - m.pre_first_forward(recipe=fp8_recipe) + m.reset_recipe_state(recipe=fp8_recipe) module_tensors = m._save_fp8_metas() fp8_tensors.append(module_tensors) return fp8_tensors @@ -779,9 +854,9 @@ def make_graphed_callables( num_warmup_iters: int = 3, allow_unused_input: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, - fp8_enabled: bool = False, + fp8_enabled: SingleOrTuple[bool] = False, fp8_calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, @@ -825,14 +900,15 @@ def make_graphed_callables( FP8-related parameters ---------------------- - fp8_enabled: bool, default = `True` - whether or not to enable fp8 + fp8_enabled: (tuple of) bool, default = `False` + whether or not to enable fp8. + If tuple, the length must match the number of modules. fp8_calibrating: bool, default = `False` calibration mode allows collecting statistics such as amax and scale data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision. 
- fp8_recipe: recipe.DelayedScaling, default = `None` + fp8_recipe: Recipe, default = `None` recipe used for FP8 training. fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` distributed group over which amaxes for the fp8 tensors @@ -848,14 +924,25 @@ def make_graphed_callables( """ set_capture_start() - fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - # Handle single module. just_one_callable = False if not isinstance(modules, tuple): just_one_callable = True modules = (modules,) + if not isinstance(fp8_enabled, tuple): + assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools" + fp8_enabled = (fp8_enabled,) * len(modules) + else: + assert len(fp8_enabled) == len( + modules + ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})" + if any(fp8_enabled) and fp8_recipe is None: + fp8_recipe = get_default_fp8_recipe() + elif not any(fp8_enabled): + fp8_recipe = None + module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled)) + # Store FP8 tensors to reset later. saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) @@ -870,15 +957,15 @@ def wrap_autocast(block): old_call_funcs[block_cls] = block_cls.__call__ # Wrap the original call function of the module class. - def call_func(*args, **kwargs): + def call_func(self, *args, **kwargs): with fp8_autocast( - enabled=fp8_enabled, + enabled=module_uses_fp8.get(id(self), False), calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, fp8_group=fp8_group, _graph=True, ): - outputs = old_call_funcs[block_cls](*args, **kwargs) + outputs = old_call_funcs[block_cls](self, *args, **kwargs) return outputs block_cls.__call__ = call_func diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index cdd08766c..db2130877 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -138,30 +138,43 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: @jit_fuser def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor: """L2 normalization fused - inference version""" - x_squared = x.pow(2) + x_fp32 = x.float() + x_squared = x_fp32.pow(2) l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) - return x * rsqrt_norm + y_fp32 = x_fp32 * rsqrt_norm + return y_fp32.to(x.dtype) @jit_fuser def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: """L2 normalization fused - training version that returns intermediate values""" - x_squared = x.pow(2) + x_fp32 = x.float() + x_squared = x_fp32.pow(2) l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) - rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) - y = x * rsqrt_norm + l2_norm_squared_eps = l2_norm_squared + eps + rsqrt_norm = torch.rsqrt(l2_norm_squared_eps) + y_fp32 = x_fp32 * rsqrt_norm + y = y_fp32.to(x.dtype) return y, rsqrt_norm @jit_fuser def l2normalization_backward_fused_( - grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float + grad_output: torch.Tensor, + x: torch.Tensor, + rsqrt_norm: torch.Tensor, + eps: float, ) -> torch.Tensor: """L2 normalization backward fused""" - x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True) - x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps - return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared) + x_fp32 = x.float() + grad_output_fp32 = grad_output.float() + x_dy_sum = (x_fp32 * grad_output_fp32).sum(dim=-1, 
keepdim=True) + x_squared = x_fp32.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) + x_norm_squared = l2_norm_squared + eps + dx_fp32 = rsqrt_norm * (grad_output_fp32 - x_fp32 * x_dy_sum / x_norm_squared) + return dx_fp32.to(x.dtype) def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: @@ -195,7 +208,10 @@ def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor def l2normalization_backward_fused( - grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float + grad_output: torch.Tensor, + x: torch.Tensor, + rsqrt_norm: torch.Tensor, + eps: float, ) -> torch.Tensor: """Disable native AMP for l2normalization_backward_fused_""" with gpu_autocast_ctx(enabled=False): diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 5074d32aa..ac682190c 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -11,4 +11,4 @@ from .rmsnorm import RMSNorm from .fp8_padding import Fp8Padding from .fp8_unpadding import Fp8Unpadding -from .base import initialize_ub, destroy_ub +from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a6ab1b22a..b49e38544 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,6 +10,7 @@ import os import pickle import warnings +from enum import Enum from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager @@ -47,18 +48,16 @@ from ..tensor.fsdp2_allgather_tensor import FSDPAGTensor from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..utils import get_device_compute_capability, torch_get_autocast_gpu_dtype if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton - -from ..utils import is_non_tn_fp8_gemm_supported +from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor +from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled - -__all__ = ["initialize_ub", "destroy_ub"] +__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -72,6 +71,15 @@ layers_atomic_ring_exchange = [] +class UserBufferQuantizationMode(Enum): + """ + UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer. 
+ """ + + NONE = "none" + FP8 = "fp8" + + def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" if IS_HIP_EXTENSION: @@ -126,8 +134,9 @@ def initialize_ub( shape: list, tp_size: int, use_fp8: bool = False, + quantization_modes: List[UserBufferQuantizationMode] = None, dtype: torch.dtype = torch.bfloat16, - ub_cfgs: Optional[dict] = None, + ub_cfgs: Optional[Union[dict, List[dict]]] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: r""" @@ -143,7 +152,11 @@ def initialize_ub( tp_size : int number of GPUs in the tensor-parallel process group use_fp8 : bool = False - allocate the communication buffer for FP8 GEMM inputs/outputs + allocate the communication buffer for FP8 GEMM inputs/outputs. + DEPRECATED: Please use `quantization_modes` instead. + quantization_modes : List[UserBufferQuantizationMode] = None + if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. + falls back to the legacy `use_fp8` parameter if `None` is provided. dtype : torch.dtype = torch.bfloat16 non-FP8 data type of the communication buffer when `use_fp8 = False` ub_cfgs: dict = None @@ -166,7 +179,8 @@ def initialize_ub( ``` for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", - "fc2_fprop", "fc2_dgrad"]`. + "fc2_fprop", "fc2_wgrad"]`. + a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` bootstrap_backend : str = None `torch.distributed` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are @@ -183,6 +197,28 @@ def initialize_ub( + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." ) + if not quantization_modes: + warnings.warn( + "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes" + " instead.", + DeprecationWarning, + ) + quantization_modes = [ + UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE + ] + else: + assert isinstance(quantization_modes, list), "quantization_modes must be a list" + assert all( + isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes + ), "quantization_modes must be a list of UserBufferQuantizationMode" + + if isinstance(ub_cfgs, dict) or ub_cfgs is None: + ub_cfgs = [ub_cfgs] * len(quantization_modes) + else: + assert len(ub_cfgs) == len( + quantization_modes + ), "Number of ub_cfgs settings must match number of quantization configurations" + global _ub_communicators assert _ub_communicators is None, "UB communicators are already initialized." 
_ub_communicators = {} @@ -265,22 +301,31 @@ def initialize_ub( "qkv_fprop", "qkv_dgrad", "proj_dgrad", + "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", + "fc2_wgrad", ] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] # Default overlap methods for layers methods = { - "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], + "ring_exchange": [ + "qkv_fprop", + "fc1_fprop", + "proj_dgrad", + "fc2_dgrad", + ], "pipeline": ["proj_fprop", "fc2_fprop"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + "external": ["proj_wgrad", "fc2_wgrad"], } # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} + external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] @@ -315,6 +360,7 @@ def get_default_config(name): def add_ub( name: str, + quantization_mode: UserBufferQuantizationMode, method: str, is_reduce_scatter: bool, num_sm: int = 16, @@ -333,8 +379,10 @@ def add_ub( warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) - assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." - if method == "bulk": + assert ( + quantization_mode == UserBufferQuantizationMode.FP8 + ), "Atomic GEMM overlap supported only for FP8 GEMM." + if method in ("bulk", "external"): warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." "Defaulting to `atomic_gemm=False`." @@ -363,7 +411,21 @@ def add_ub( if atomic_gemm and method == "ring_exchange": assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message - buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype + if name in external_gemm_to_overlap: + assert method == "external", ( + f"At {name}, `external` overlap method is specified, but the selected method is" + f" {method}" + ) + assert external_gemm_to_overlap[name] in methods["ring_exchange"], ( + f"At {name}, `external` overlap method is specified, but the external gemm" + f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" + ) + + buffer_dtype = ( + torch.uint8 + if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) + else dtype + ) if method == "ring_exchange": ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape @@ -397,36 +459,47 @@ def add_ub( comm_priority=comm_priority, rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) - _ub_communicators[name] = ub_obj - - if ub_cfgs is not None: - for name in dgrad_reduce_scatter_overlap: - if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk": - wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in ub_cfgs - layers_reduce_scatter_overlap.remove(wgrad_name) - layers_all_gather_overlap.remove(name) - layers_reduce_scatter_overlap.append(name) - methods["bulk"].remove(name) - new_method = ub_cfgs[name]["method"] - methods[new_method].append(name) - - for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: - ub_cfg = get_default_config(name) - if ub_cfgs is not None and name in ub_cfgs: - fp8_buf = (name in layers_all_gather_overlap) or ( - ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"] - ) - ub_cfg.update(ub_cfgs[name]) - ub_cfg["fp8_buf"] = fp8_buf - add_ub(name, **ub_cfg) + 
_ub_communicators[(name, quantization_mode)] = ub_obj + + for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): + if user_ub_cfg is not None: + for name in dgrad_reduce_scatter_overlap: + if ( + name in user_ub_cfg + and "method" in user_ub_cfg[name] + and user_ub_cfg[name]["method"] != "bulk" + ): + wgrad_name = name.replace("dgrad", "wgrad") + assert wgrad_name not in user_ub_cfg + layers_reduce_scatter_overlap.remove(wgrad_name) + layers_all_gather_overlap.remove(name) + layers_reduce_scatter_overlap.append(name) + methods["bulk"].remove(name) + new_method = user_ub_cfg[name]["method"] + methods[new_method].append(name) + + for name in ( + methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] + ): + ub_cfg = get_default_config(name) + if user_ub_cfg is not None and name in user_ub_cfg: + fp8_buf = (name in layers_all_gather_overlap) or ( + user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] + ) + ub_cfg.update(user_ub_cfg[name]) + ub_cfg["fp8_buf"] = fp8_buf + add_ub(name, quantization_mode, **ub_cfg) -def get_ub(name: str): +def get_ub(name: str, use_fp8: bool): """Get userbuffer communicator corresponding to give key.""" + # For now use `use_fp8` boolean input as it matches the current design in the modules + # So favour simplicity until the correct design becomes clear. + # This is mainly an internal API so we don't need to worry about future changes + key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) assert _ub_communicators is not None, "UB manager is not initialized." - assert name in _ub_communicators, f"UB for {name} is not registered." - return _ub_communicators[name] + assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." + return _ub_communicators[key] def destroy_ub(): @@ -580,6 +653,7 @@ def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." 
self.name = None + self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False self.fp8 = False self.fp8_calibration = False @@ -597,9 +671,10 @@ def __init__(self) -> None: self.fsdp_wrapped = False self.fsdp_group = None self._fp8_workspaces: Dict[str, QuantizedTensor] = {} - self.activation_dtype: Optional[torch.dtype] = None, - self.keep_fp8_weight_transpose_cache: bool = True, + self.activation_dtype: Optional[torch.dtype] = None + self.keep_fp8_weight_transpose_cache: bool = True self.use_fsdp2 = False + self.wgrad_accumulation_and_reduce_hooks = [] if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -1326,21 +1401,29 @@ def get_weight_workspace( # Try getting workspace from cache out = None - if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) - if quantizer is not None and isinstance(out, MXFP8TensorBase): + + # Reset cache if workspace is invalid + if out is not None and quantizer is not None: + reset_cache = False + if isinstance(out, Float8TensorBase): + if ( + not is_non_tn_fp8_gemm_supported() + and quantizer.columnwise_usage + and out._transpose is None + ): + reset_cache = True + elif isinstance(out, MXFP8TensorBase): if quantizer.rowwise_usage and out._rowwise_data is None: - out = None - del self._fp8_workspaces[cache_name] + reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: - out = None - del self._fp8_workspaces[cache_name] - - is_debug = isinstance(quantizer, DebugQuantizer) - is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) - if is_debug != is_out_debug_tensor: + reset_cache = True + if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): + reset_cache = True + if reset_cache: out = None + del self._fp8_workspaces[cache_name] # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work @@ -1414,6 +1497,16 @@ def _load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook): + """ + This method is used to manually control the weight gradient accumulation and reduce. + This method should be called before the backward() method. + Set the skip_wgrad_accumulation_and_reduce to True to skip the weight gradient accumulation + and reduce in backward(); + And register the wgrad_accumulation_and_reduce_func to be called in backward_dw() method. + """ + self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook) + def backward_dw(self): """ Execute the delayed weight gradient computation. 
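The new ``register_wgrad_accumulation_and_reduce_hooks`` method above pairs with the delayed weight-gradient path: when ``delay_wgrad_compute`` is enabled, ``backward()`` produces only data gradients, and the weight gradient is materialized later by ``backward_dw()``, which then invokes the registered hooks. A hedged usage sketch follows; the layer, hook, and the assumption that ``te.Linear`` accepts ``delay_wgrad_compute`` (as the LayerNormLinear/GroupedLinear changes in this patch do) are illustrative rather than taken from this diff:

.. code-block:: python

    # Illustrative sketch, not from this patch.
    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(1024, 1024, delay_wgrad_compute=True)

    def grad_sync_hook():
        # Called from backward_dw() once the delayed wgrad exists, e.g. to
        # launch a gradient all-reduce that backward() intentionally skipped.
        pass

    layer.register_wgrad_accumulation_and_reduce_hooks(grad_sync_hook)

    x = torch.randn(32, 1024, device="cuda", requires_grad=True)
    layer(x).sum().backward()  # dgrad only; the wgrad GEMM is stored for later
    layer.backward_dw()        # runs the wgrad GEMM, then the registered hooks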
@@ -1424,14 +1517,57 @@ def backward_dw(self): with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): (wgrad, bgrad), _ = self.wgrad_store.pop() if not self.fuse_wgrad_accumulation: - unfused_weights = [getattr(self, name) for name in self.weight_names] - weight_tensor = noop_cat(unfused_weights) - if weight_tensor.grad is None: - weight_tensor.grad = wgrad.to(weight_tensor.dtype) + weight_tensor = noop_cat(self._get_weight_tensors()) + weight_tensor.grad = wgrad.to(weight_tensor.dtype) if self.use_bias: bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) if bias_tensor.grad is None: bias_tensor.grad = bgrad.to(bias_tensor.dtype) + del wgrad + del bgrad + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() + + def is_debug_iter(self) -> bool: + """ + This function checks if the debug should be enabled for this layer. + """ + debug = TEDebugState.debug_enabled + if not debug: + return False + self._validate_name() + + # If layer is run first time in new iteration, + # we need to check if the debug should be enabled for this layer - + # maybe in previous iterations debug features returned information + # that no feature will be active for this layer for multiple next iterations. + started_new_iteration = TEDebugState.get_iteration() != getattr( + self, "debug_last_iteration", None + ) + if started_new_iteration: + if self.next_iter_when_debug_should_be_run is None: + debug = False + else: + debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run + self.debug_last_iteration = TEDebugState.get_iteration() + return debug + + def no_debug_features_active(self, quantizers): + """ + Checks if any debug feature is active for this layer. + """ + run_current = any_feature_enabled(quantizers) + + # Sometimes features inform that they will not be enabled for particular layer + # for multiple next iterations. + self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) + + if not run_current: + return True + + if self.primary_weights_in_fp8: + raise RuntimeError("FP8 weights are not supported in debug mode.") + return False def _validate_name(self): """ @@ -1439,6 +1575,8 @@ def _validate_name(self): This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. If no name is assigned, it creates a default name with layer count as the variable. """ + if self.name is not None: + return assert TEDebugState.debug_enabled import nvdlfw_inspect.api as debug_api @@ -1487,29 +1625,3 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: " Please check the recipes assigned during fp8_model_init() and" " fp8_autocast() calls." ) - - def _turn_off_unsupported_features_in_debug(self): - if ( - getattr(self, "ub_bulk_wgrad", False) - or getattr(self, "ub_bulk_dgrad", False) - or getattr(self, "ub_overlap_ag", False) - or getattr(self, "ub_overlap_rs_dgrad", False) - or getattr(self, "ub_overlap_rs", False) - ): - import nvdlfw_inspect.api as debug_api - - debug_api.log_message( - "UserBuffers are not supported in debug module. " - "Using UB optimization will not affect the debug module. 
", - level=logging.WARNING, - ) - if hasattr(self, "ub_bulk_wgrad"): - self.ub_bulk_wgrad = None - if hasattr(self, "ub_bulk_dgrad"): - self.ub_bulk_dgrad = None - if hasattr(self, "ub_overlap_ag"): - self.ub_overlap_ag = None - if hasattr(self, "ub_overlap_rs_dgrad"): - self.ub_overlap_rs_dgrad = None - if hasattr(self, "ub_overlap_rs"): - self.ub_overlap_rs = None diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index da66e68b4..5749d96c9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -452,9 +452,6 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): - wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() @@ -662,6 +659,12 @@ def __init__( self.reset_parameters(defer_init=device == "meta") + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + for i in range(self.num_gemms): + if name in (f"weight{i}", f"bias{i}"): + param.skip_backward_post_hook = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -742,7 +745,9 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] @@ -817,19 +822,20 @@ def backward_dw(self): with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() wgrad_list = tensor_list[2] + weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): - weight_param = getattr(self, f"weight{i}") - if weight_param.grad is None: - weight_param.grad = wgrad_list[i].to(weight_param.dtype) + weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: for i in range(self.num_gemms): - bias_param = getattr(self, f"bias{i}") - if bias_param.grad is None: - bias_param.grad = grad_biases_[i].to(bias_param.dtype) + if bias_params[i].grad is None: + bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) del grad_biases_ del wgrad_list del tensor_list + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + linear.""" @@ -877,7 +883,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] * self.num_gemms weight_quantizers = [ self.quantizers["scaling_fwd"][ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5af56b2e0..d1aeebc9f 
100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -66,9 +66,8 @@ restore_from_saved, ) from ...debug.pytorch.debug_state import TEDebugState -from ...debug.pytorch.utils import any_feature_enabled from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase @@ -173,16 +172,23 @@ def forward( with_input_all_gather = parallel_mode == "column" and sequence_parallel # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_ag_fprop = False + ub_overlap_rs_fprop = False + ub_overlap_ag_dgrad = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False ub_obj = None ub_type = None ub_overlap_ag_fprop = ( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output ) if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.RS elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG # Configure quantizer for norm output @@ -190,9 +196,7 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) - if with_input_all_gather and isinstance( - input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ): + if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) @@ -369,6 +373,9 @@ def forward( if not weight.requires_grad and not return_layernorm_output: clear_tensor_data(ln_out, ln_out_total) ln_out = ln_out_total = None + elif with_input_all_gather and not return_layernorm_output_gathered: + clear_tensor_data(ln_out_total) + ln_out_total = None # ------------------------------------------------------ # Prepare output tensor @@ -593,23 +600,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", 
ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -659,7 +666,7 @@ def backward( quantizer = None if ctx.input_quantizer is not None: quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) else: @@ -780,27 +787,36 @@ def backward( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() + + # This object is separate from the ub_obj_wgrad object which is passed to the GEMM + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - grad_output, mxfp8_grad_output_work = gather_along_first_dim( + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. 
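The stream handling above copies the grad output into the userbuffers region on the dgrad send stream and then gathers it with tex.bulk_overlap_ag_with_external_gemm, so the all-gather overlaps the dgrad GEMM. The generic CUDA pattern behind this, sketched with plain torch.distributed under the assumption of an initialized NCCL process group and CUDA tensors; the function name and shapes below are illustrative, not the Transformer Engine API.

import torch
import torch.distributed as dist

def overlap_allgather_with_gemm(grad_out, a, b, group=None):
    """All-gather grad_out on a side stream while a GEMM runs on the main stream."""
    comm_stream = torch.cuda.Stream()
    world = dist.get_world_size(group)
    gathered = torch.empty(world * grad_out.shape[0], *grad_out.shape[1:],
                           device=grad_out.device, dtype=grad_out.dtype)

    # The side stream must wait until grad_out has been produced on the main stream,
    # and the caching allocator must know grad_out is in use on that stream.
    comm_stream.wait_stream(torch.cuda.current_stream())
    grad_out.record_stream(comm_stream)
    with torch.cuda.stream(comm_stream):
        dist.all_gather_into_tensor(gathered, grad_out, group=group)

    dgrad = a @ b.t()  # stand-in for the dgrad GEMM, overlapping with the gather

    # The consumer (the wgrad GEMM) must wait for the gather to finish.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return dgrad, gathered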
+ with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, grad_outputs[0], + ctx.grad_output_quantizer, ctx.tp_group, - async_op=True, - quantizer=ctx.grad_output_quantizer, ) - # Synchronize with the main stream - mxfp8_grad_output_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -904,9 +920,19 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted - if not ctx.return_layernorm_output: + # Deallocate input tensors if permitted + if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + # Input tensors have not been exposed externally + clear_tensor_data(ln_out) + elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered: + # Non-gathered input has not been exposed externally + clear_tensor_data(ln_out) + if ctx.ln_out_needs_gather: + # Gathered input is internal clear_tensor_data(ln_out_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: @@ -1206,14 +1232,14 @@ def __init__( self.return_bias = return_bias self.apply_bias = self.use_bias and not return_bias self.return_layernorm_output = return_layernorm_output - self.return_layernorm_output_gathered = return_layernorm_output_gathered + self.return_layernorm_output_gathered = ( + return_layernorm_output_gathered if return_layernorm_output else False + ) self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name - if TEDebugState.debug_enabled: - self._turn_off_unsupported_features_in_debug() # turn off userbuffers self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False @@ -1433,6 +1459,11 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in self.weight_names or name in self.bias_names: + param.skip_backward_post_hook = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1517,9 +1548,8 @@ def forward( """ if is_in_onnx_export_mode(): return self.onnx_forward(inp, fp8_output) - debug = TEDebugState.debug_enabled - if debug: - self._validate_name() + + debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1529,13 +1559,19 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): 
fp8_grad = True - with self.prepare_forward( + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer ) as inp: @@ -1548,13 +1584,9 @@ def forward( else self._get_debug_quantizers(fp8_output, fp8_grad) ) if debug: - if not any_feature_enabled(quantizers): - # If no feature is used, then run faster implementation with debug = False. - quantizers = self._get_quantizers(fp8_output, fp8_grad) + if self.no_debug_features_active(quantizers): debug = False - - if isinstance(weight_tensor, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + quantizers = self._get_quantizers(fp8_output, fp8_grad) ( input_quantizer, @@ -1804,7 +1836,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8772418c9..4492abe3e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -71,7 +71,6 @@ from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..tensor.quantized_tensor import ( - QuantizedTensor, QuantizedTensorBase, Quantizer, prepare_for_saving, @@ -81,7 +80,6 @@ general_gemm, ) from ..export import is_in_onnx_export_mode, assert_warmed_up -from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.debug_state import TEDebugState if IS_HIP_EXTENSION: @@ -96,39 +94,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): # bf16 (recipe is None): return { "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "relu": (tex.relu, tex.drelu, None), + "reglu": (tex.reglu, tex.dreglu, None), "srelu": (tex.srelu, tex.dsrelu, None), + "sreglu": (tex.sreglu, tex.dsreglu, None), + "silu": (tex.silu, tex.dsilu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), } if recipe.delayed() or recipe.mxfp8(): # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] return { "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), - "relu": (tex.relu, tex.drelu, tex.dbias_drelu), "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "reglu": (tex.reglu, tex.dreglu, None), "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + "sreglu": (tex.sreglu, tex.dsreglu, None), + "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu), + "swiglu": (tex.swiglu, tex.dswiglu, None), } # no activation fusion written yet # Per-tensor current scaling or fp8 blockwise scaling: [] if 
recipe.float8_current_scaling() or recipe.float8_block_scaling(): return { "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "relu": (tex.relu, tex.drelu, None), + "reglu": (tex.reglu, tex.dreglu, None), "srelu": (tex.srelu, tex.dsrelu, None), + "sreglu": (tex.sreglu, tex.dsreglu, None), + "silu": (tex.silu, tex.dsilu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), } raise NotImplementedError(f"Unhandled recipe type {recipe}") @@ -232,6 +236,12 @@ def forward( device = inp.device # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_ag = False + ub_overlap_rs = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered ub_overlap_rs = ub_overlap_rs and is_grad_enabled @@ -247,9 +257,7 @@ def forward( if fc1_input_quantizer is None: raise ValueError("Missing quantizer for FC1 input tensor") fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input) - if sequence_parallel and isinstance( - fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ): + if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data fc1_input_quantizer.set_usage(columnwise=False) @@ -312,7 +320,7 @@ def forward( fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: # Copy into Userbuffers buffer - ub_obj_lnout = get_ub("fc1_fprop") + ub_obj_lnout = get_ub("fc1_fprop", fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_lnout, ln_out, @@ -454,20 +462,25 @@ def forward( act_out = activation_func(fc1_out, None) act_out = tex.quantize(act_out, fc2_input_quantizer) else: - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8_calibration: + act_out = activation_func(fc1_out, None) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) - if fp8_calibration: - fc2_input_quantizer.calibrate(act_out) - fc2_weight_quantizer.calibrate(fc2_weight) + if not fp8 and fp8_calibration: + if fc2_input_quantizer is not None: + fc2_input_quantizer.calibrate(act_out) + if fc2_weight_quantizer is not None: + fc2_weight_quantizer.calibrate(fc2_weight) # Configure Userbuffers reduce-scatter if needed ub_obj_fc2out = None reduce_scatter_out = None if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") + ub_obj_fc2out = get_ub("fc2_fprop", fp8) dim_size = list(act_out.size()) dim_size[0] //= tp_world_size dim_size[-1] = fc2_weight.size(0) @@ -757,7 +770,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad") + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -781,7 +794,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -877,26 +890,37 @@ def backward( # 
Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ( + ub_obj_fc2_dgrad.get_communication_stream() + ) + + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) + ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim( + + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. + with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_fc2_wgrad, grad_outputs[0], + ctx.fc2_grad_output_quantizer, ctx.tp_group, - async_op=True, - quantizer=ctx.fc2_grad_output_quantizer, ) - # Synchronize with the main stream - mxfp8_fc2_grad_output_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_fc2_wgrad, dgrad_send_stream, dgrad_recv_stream + ) # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -1045,16 +1069,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad") + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS @@ -1210,7 +1234,6 @@ def fc1_wgrad_gemm( "with Userbuffers (tensor-parallel communication overlapping)" ) ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm) - fc1_wgrad = None if fuse_gemm_and_bias_fc1_wgrad: fc1_bias_grad = None else: @@ -1402,7 +1425,7 @@ def fc1_wgrad_gemm( class LayerNormMLP(TransformerEngineBaseModule): r""" Applies layer normalization on the input followed by the MLP module, consisting of - 2 successive linear transformations, separated by the GeLU activation. + 2 successive linear transformations, separated by the activation function. Parameters ---------- @@ -1418,7 +1441,8 @@ class LayerNormMLP(TransformerEngineBaseModule): type of normalization applied. 
activation : str, default = 'gelu' activation function used. - Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'. + Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', + 'silu', and 'swiglu'. init_method : Callable, default = `None` used for initializing FC1 weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. @@ -1576,13 +1600,14 @@ def __init__( self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and self.activation == "gelu" - and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) + and all( + ("fc1_fprop", use_fp8) not in _ub_communicators + or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm() + for use_fp8 in [False, True] + ) ) self.name = name - if TEDebugState.debug_enabled: - self._turn_off_unsupported_features_in_debug() # turn off userbuffers - self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if tp_group is None: @@ -1639,7 +1664,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]: + if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1699,6 +1724,10 @@ def __init__( warmup_jit_bias_gelu_all_dtypes( self.size_per_partition, seq_length, micro_batch_size ) + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in ["fc1_weight", "fc2_weight", "fc1_bias", "fc2_bias"]: + param.skip_backward_post_hook = True # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. 
These envvars can be used to prevent the LN @@ -1781,9 +1810,8 @@ def forward( """ if is_in_onnx_export_mode(): return self.onnx_forward(inp) - debug = TEDebugState.debug_enabled - if debug: - self._validate_name() + + debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1794,10 +1822,12 @@ def forward( fp8_output = False if self.ub_overlap_rs: - if get_ub("fc2_fprop").is_fp8_ubuf(): + if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward(inp, num_gemms=2) as inp: quantizers = ( self._get_quantizers(fp8_output) @@ -1805,12 +1835,9 @@ def forward( else self._get_debug_quantizers(fp8_output) ) if debug: - if not any_feature_enabled(quantizers): - quantizers = self._get_quantizers(fp8_output) + if self.no_debug_features_active(quantizers): debug = False - - if isinstance(self.fc1_weight, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + quantizers = self._get_quantizers(fp8_output) # Get quantizers ( @@ -1935,7 +1962,7 @@ def _get_quantizers(self, fp8_output): fc2_grad_output_quantizer, ) = [None] * 10 fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() - if self.fp8: + if self.fp8 or self.fp8_calibration: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] @@ -2021,14 +2048,17 @@ def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Ten activation_map = { "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), - "relu": torch.nn.functional.relu, "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], - "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], - "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") * x.chunk(2, -1)[1], - "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), - "srelu": torch.nn.functional.softplus, + "relu": torch.nn.functional.relu, + "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "srelu": lambda x: torch.nn.functional.relu(x) ** 2, + "sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2 + * x.chunk(2, -1)[1], + "silu": torch.nn.functional.silu, + "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], } if self.activation not in activation_map: raise ValueError(f"Unsupported activation in onnx export: {self.activation}") @@ -2149,7 +2179,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None, None] fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True @@ -2206,11 +2236,11 @@ def backward_dw(self): if self.fc1_bias.grad is None: self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype) if not 
self.fuse_wgrad_accumulation: - if self.fc2_weight.grad is None: - self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) - if self.fc1_weight.grad is None: - self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) + self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) + self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) del fc2_bias_grad_ del fc2_wgrad del fc1_wgrad del fc1_bias_grad + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index fb5592540..88ed6356b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -67,12 +67,9 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState -from ...debug.pytorch.utils import any_feature_enabled from torch.utils.cpp_extension import IS_HIP_EXTENSION __all__ = ["Linear"] @@ -144,13 +141,19 @@ def forward( ) # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_rs_fprop = False + ub_overlap_ag_fprop = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False ub_obj = None ub_type = None if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.RS elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG # ------------------------------------------------------ @@ -175,16 +178,19 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if not isinstance(inputmat, QuantizedTensorBase): - input_quantizer.set_usage( - rowwise=True, columnwise=backward_needs_input and not save_original_input - ) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) + if save_original_input: + # No need for column-wise data since this + # tensor will not be cached for backward pass + input_quantizer.set_usage(columnwise=False) + own_quantized_input = False inputmat = input_quantizer(inputmat) - own_quantized_input = True else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP @@ -319,6 +325,13 @@ def forward( # Finished forward GEMM... # ------------------------------------------------------ + # Deallocate GEMM input tensor if no longer needed + # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. 
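The rowwise/columnwise usage flags toggled above track which GEMMs still need a tensor: the forward GEMM reads the input as-is (rowwise), while the wgrad GEMM effectively reads its transpose (columnwise), so a cached input that will not be gathered again can drop its rowwise data. A small plain-PyTorch check of that bookkeeping, with no quantizers involved:

import torch

x = torch.randn(8, 16)    # input (rows = tokens)
w = torch.randn(32, 16)   # weight (out_features x in_features)
dy = torch.randn(8, 32)   # incoming grad for the output

y = x @ w.t()         # fprop GEMM: needs x in its "rowwise" form
dgrad = dy @ w        # dgrad GEMM: needs only dy and w
wgrad = dy.t() @ x    # wgrad GEMM: effectively needs x transposed ("columnwise")

# Sanity check against autograd
x_ref = x.clone().requires_grad_()
w_ref = w.clone().requires_grad_()
(x_ref @ w_ref.t()).backward(dy)
assert torch.allclose(dgrad, x_ref.grad)
assert torch.allclose(wgrad, w_ref.grad)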
+ if with_input_all_gather_nccl: + clear_tensor_data(inputmat_total) + inputmat_total = None + # ------------------------------------------------------ # Prepare output tensor # Note: Perform tensor-parallel communication @@ -352,23 +365,30 @@ def forward( inputmat = inp ctx.weight_quantizer = weight_quantizer - saved_inputmat = None ctx.backward_input_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + # Discard unneeded data in input tensor + if ( + backward_needs_input + and own_quantized_input + and isinstance(inputmat, QuantizedTensorBase) + ): + if ( + ctx.backward_input_needs_gather + and weight_quantizer.supports_only_rowwise_all_gather() + ): + # All-gather is not supported with FP8 column-wise data + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + # Discard row-wise data since it is not needed in backward pass + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Cached input tensor + saved_inputmat = None if backward_needs_input: - if not save_original_input: - if own_quantized_input and isinstance(inputmat, QuantizedTensorBase): - # For sequence parallel in vanilla FP8, rowwise data is - # to gather the input. For MXFP8, columnwise only data - # can be allgathered. - if ( - isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) - or not ctx.backward_input_needs_gather - ): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. @@ -519,23 +539,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -558,6 +578,19 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # usage for only dgrad GEMM. quantizer.set_usage(columnwise=False) + # Adjust the quantization direction approach depending + # on whether wgrad calculations will be performed. + # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization + # results in `Assertion failed: output_tensor->has_data(). 
Quantizing in only the columnwise direction not supported yet!` + # NOTE: For `ctx.bias is True`, selected quantize kernel errors with + # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.` + if ( + not ctx.use_bias + and not ctx.requires_wgrad + and ctx.grad_output_quantizer is not None + ): + ctx.grad_output_quantizer.set_usage(columnwise=False) + # Prepare grad output tensor nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( @@ -584,11 +617,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - input_is_quantized = isinstance(inputmat, QuantizedTensorBase) if ctx.fp8 or ctx.debug: - if not input_is_quantized: + if isinstance(inputmat, QuantizedTensorBase): + # Input tensor is already quantized + pass + elif ctx.debug: + # Debug quantizer will be applied immediately before wgrad GEMM + pass + else: + # Quantize input tensor quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + if quantizer.supports_only_rowwise_all_gather(): + # All-gather is not supported with FP8 column-wise data quantizer.set_usage( rowwise=True, columnwise=not ctx.backward_input_needs_gather, @@ -597,7 +637,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: - if input_is_quantized: + if isinstance(inputmat, QuantizedTensorBase): inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) else: inputmat = cast_if_needed(inputmat, ctx.activation_dtype) @@ -605,7 +645,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantizer = None if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) else: @@ -740,26 +780,36 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. 
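The isinstance checks against Float8Quantizer and Float8CurrentScalingQuantizer are replaced above by a supports_only_rowwise_all_gather() capability method on the quantizer. A sketch of the pattern, assuming the base class answers False and the per-tensor FP8 quantizers override it; the class bodies below are illustrative stubs, not the real quantizer implementations.

class Quantizer:
    def supports_only_rowwise_all_gather(self) -> bool:
        # Default: gathered data may keep the column-wise (transposed) layout too.
        return False

class Float8PerTensorQuantizerStub(Quantizer):
    def supports_only_rowwise_all_gather(self) -> bool:
        # Per-tensor FP8: all-gather of column-wise data is not supported,
        # so callers drop column-wise usage before gathering.
        return True

def usage_before_all_gather(quantizer: Quantizer) -> dict:
    usage = {"rowwise": True, "columnwise": True}
    if quantizer.supports_only_rowwise_all_gather():
        usage["columnwise"] = False
    return usage

assert usage_before_all_gather(Float8PerTensorQuantizerStub()) == {"rowwise": True, "columnwise": False}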
+ + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() + + # This object is separate from the ub_obj_wgrad object which is passed to the GEMM + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - grad_output, grad_output_work = gather_along_first_dim( + + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. + with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, grad_output_arg, + ctx.grad_output_quantizer, ctx.tp_group, - async_op=True, - quantizer=ctx.grad_output_quantizer, ) - # Synchronize with the main stream - grad_output_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): @@ -850,9 +900,16 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted + # Deallocate tensors if permitted if ctx.owns_input: + # Input tensor is internal clear_tensor_data(inputmat_total) + elif ctx.backward_input_needs_gather: + # Gathered input tensor is internal + clear_tensor_data(inputmat_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: @@ -1098,9 +1155,6 @@ def __init__( self.save_original_input = save_original_input self.name = name - if TEDebugState.debug_enabled: - self._turn_off_unsupported_features_in_debug() # turn off userbuffers - self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False @@ -1294,6 +1348,11 @@ def __init__( else: self.gemm_bias_unfused_add = False + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in self.weight_names or name in self.bias_names: + param.skip_backward_post_hook = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1359,9 +1418,7 @@ def forward( if is_in_onnx_export_mode(): return self.onnx_forward(inp, fp8_output) - debug = TEDebugState.debug_enabled - if debug: - self._validate_name() + debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1371,13 +1428,19 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_fprop", 
FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: @@ -1389,14 +1452,11 @@ def forward( if not debug else self._get_debug_quantizers(fp8_output, fp8_grad) ) + if debug: - if not any_feature_enabled(quantizers): - # If no feature is used, then run faster implementation with debug = False. - quantizers = self._get_quantizers(fp8_output, fp8_grad) + if self.no_debug_features_active(quantizers): debug = False - - if isinstance(weight_tensor, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + quantizers = self._get_quantizers(fp8_output, fp8_grad) ( input_quantizer, @@ -1636,7 +1696,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/onnx_extensions.py b/transformer_engine/pytorch/onnx_extensions.py index e34fd7846..38df5fc54 100644 --- a/transformer_engine/pytorch/onnx_extensions.py +++ b/transformer_engine/pytorch/onnx_extensions.py @@ -112,7 +112,9 @@ def onnx_quantize_fp8_symbolic( doc="TRT FP8 Quantize Linear used for inference.", inputs=[ defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"), - defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(float)", "Inverse scale factor for quantization" + ), ], outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")], ) @@ -126,11 +128,10 @@ def onnx_quantize_fp8_symbolic( @torch.library.custom_op("tex::fp8_dequantize", mutates_args=[]) -def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: +def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: """Dequantize from Float8Tensor used for inference.""" - scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device) quantizer = Float8Quantizer( - scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 + 1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 ) quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32) return quantizer_tensor.dequantize() @@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor: def onnx_dequantize_fp8_symbolic( - tensor: onnxscript.onnx_types.TensorType, scale: float + tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType ) -> onnxscript.onnx_types.TensorType: """Symbolic dequantize from Float8Tensor used for inference.""" - scale_inv = op.Constant(value_float=1 / scale) return TRT_FP8DequantizeLinear(tensor, scale_inv) @@ -157,7 +157,9 @@ def onnx_dequantize_fp8_symbolic( doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.", inputs=[ defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", 
"Input tensor to dequantize"), - defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(float)", "Inverse scale factor for dequantization" + ), ], outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")], ) @@ -166,6 +168,43 @@ def onnx_dequantize_fp8_symbolic( opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema ) +# ONNX FP8 Current Scaling Quantization + + +@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[]) +def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize to FP8 with current scaling; returns (uint8, scale_inv).""" + if tensor.dtype != torch.float32: + tensor = tensor.to(torch.float32) + amax = tensor.abs().max() + eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device) + amax = torch.maximum(amax, eps) + fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device) + scale = fp8_max / amax + q = torch.ops.tex.fp8_quantize(tensor, scale) + scale_inv = 1 / scale + return q, scale_inv + + +@onnx_cs_quantize_fp8_op.register_fake +def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones( + 1, dtype=torch.float32, device=tensor.device + ) + + +def onnx_quantize_fp8_cs_symbolic( + tensor: onnxscript.onnx_types.TensorType, +): + """Symbolic quantize with current scaling; computes scale_inv from tensor.""" + # scale_inv = 1 / max(abs(tensor)) + amax = op.ReduceMax(op.Abs(tensor), keepdims=0) + eps = op.Constant(value_float=1.0e-12) + amax = op.Max(amax, eps) + scale_inv = op.Div(amax, op.Constant(value_float=448.0)) + q = TRT_FP8QuantizeLinear(tensor, scale_inv) + return q, scale_inv + # ONNX MXFP8 Quantization @@ -194,12 +233,12 @@ def onnx_quantize_mxfp8_symbolic( tensor: onnxscript.onnx_types.TensorType, ) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]: """Symbolic quantize to MXFP8Tensor used for inference.""" - tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor) + tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor) return tensor_out, scale_inv_out schema = defs.OpSchema( - name="TRT_MXFP8QuantizeLinear", + name="TRT_MXFP8DynamicQuantize", domain="trt", since_version=1, doc="TRT MXFP8 Quantize Linear used for inference.", @@ -214,8 +253,8 @@ def onnx_quantize_mxfp8_symbolic( ], ) -TRT_MXFP8QuantizeLinear = onnxscript.values.Op( - opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema +TRT_MXFP8DynamicQuantize = onnxscript.values.Op( + opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema ) @@ -356,6 +395,7 @@ def onnx_attention_mask_func( torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic, torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic, torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic, + torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic, torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic, torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic, torch.ops.tex.layernorm.default: onnx_layernorm_symbolic, diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 8e997428f..99bbc34c4 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -29,7 +29,9 @@ def maybe_dequantize( if is_quantized_tensor(tensor): return 
tensor.dequantize(dtype=dtype) if dtype is not None and tensor.dtype != dtype: - return tensor.to(dtype) + tensor = tensor.to(dtype) + if not tensor.is_contiguous(): + tensor = tensor.contiguous() return tensor diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index c69e3df02..2c903675f 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,12 +4,14 @@ """Single tensor operations supported by the operation fuser.""" -from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU -from .add_in_place import AddInPlace +from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU +from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias +from .constant_scale import ConstantScale +from .dropout import Dropout from .identity import Identity from .l2normalization import L2Normalization from .layer_norm import LayerNorm diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index c077829a3..22779b601 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -11,12 +11,25 @@ import torch import transformer_engine_torch as tex -from ...fp8 import FP8GlobalStateManager +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize +__all__ = [ + "GELU", + "GEGLU", + "QGELU", + "QGEGLU", + "ReLU", + "ReGLU", + "SReLU", + "SReGLU", + "SiLU", + "SwiGLU", +] + class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): r"""Apply activation function @@ -71,7 +84,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: @@ -87,14 +100,8 @@ def op_forward( # Check input tensor x = maybe_dequantize(input_.contiguous(), dtype) - # Check if quantized compute is enabled - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - quantizer = None - if with_quantized_compute: - quantizer = next_op_input_quantizer - # Launch kernel - y = self._activation_forward_impl(x, quantizer) + y = self._activation_forward_impl(x, next_op_input_quantizer) # Quantize input to FP8 before caching if needed if self.cache_quantized_input: @@ -103,10 +110,12 @@ def op_forward( x = input_quantizer(x) # Save state for backward pass - ctx.save_for_backward(x) - ctx.with_quantized_compute = with_quantized_compute - ctx.dtype = dtype - ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) + ctx.save_for_backward(x) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer return y @@ -125,13 +134,8 @@ def op_backward( # Check grad output tensor dy = maybe_dequantize(grad_output.contiguous(), x.dtype) - # Check if quantized compute is enabled - quantizer = None - if ctx.with_quantized_compute: - quantizer = ctx.prev_op_grad_input_quantizer - # Launch kernel - dx = self._activation_backward_impl(dy, x, quantizer) 
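The maybe_dequantize change above adds a contiguity guarantee on top of the dtype cast. A standalone approximation for plain (non-quantized) tensors; the quantized-tensor branch is omitted here since it depends on TE's QuantizedTensor types.

import torch
from typing import Optional

def maybe_dequantize_dense(tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Cast to the requested dtype if needed and always return contiguous memory."""
    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()
    return tensor

t = torch.randn(4, 8, dtype=torch.float32).t()   # non-contiguous view
out = maybe_dequantize_dense(t, torch.bfloat16)
assert out.dtype == torch.bfloat16 and out.is_contiguous()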
+ dx = self._activation_backward_impl(dy, x, ctx.prev_op_grad_output_quantizer) # Clear input tensor if possible clear_tensor_data(x) @@ -159,37 +163,75 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dgelu(*args, **kwargs) -class ReLU(_ActivationOperation): - r"""Rectified linear unit +class GEGLU(_ActivationOperation): + r"""Gaussian Error Gated Linear Unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: .. math:: - \text{ReLU}(x) = \max(x,0) + \text{GEGLU}(a,b) = \text{GELU}(a) * b + + where + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.relu(*args, **kwargs) + return tex.geglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.drelu(*args, **kwargs) + return tex.dgeglu(*args, **kwargs) -class GEGLU(_ActivationOperation): - r"""Gaussian error gated linear unit +class QGELU(_ActivationOperation): + r"""Quick Gaussian Error Linear Unit + + Quick GELU from `HuggingFace`__ + and `paper`__. + + .. math:: + + \text{QGELU}(x) \approx x * \sigma(1.702 * x) + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.qgelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dqgelu(*args, **kwargs) + + +class QGEGLU(_ActivationOperation): + r"""Quick Gaussian Error Gated Linear Unit The input tensor is split into chunks :math:`a` and :math:`b` along the last dimension and the following is computed: .. math:: - \text{GEGLU}(a,b) = \text{GELU}(a) * b + \text{QGEGLU}(a,b) = \text{QGELU}(a) * b where .. math:: - \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + \text{QGELU}(x) \approx x * \sigma(1.702 * x) .. warning:: @@ -199,19 +241,33 @@ class GEGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.qgeglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dqgeglu(*args, **kwargs) + + +class ReLU(_ActivationOperation): + r"""Rectified Linear Unit + + .. math:: + + \text{ReLU}(x) = \max(x,0) """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.geglu(*args, **kwargs) + return tex.relu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.dgeglu(*args, **kwargs) + return tex.drelu(*args, **kwargs) class ReGLU(_ActivationOperation): - r"""Rectified gated linear unit + r"""Rectified Gated Linear Unit The input tensor is split into chunks :math:`a` and :math:`b` along the last dimension and the following is computed: @@ -239,6 +295,67 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dreglu(*args, **kwargs) +class SReLU(_ActivationOperation): + r"""Squared Rectified Linear Unit + + .. 
math:: + + \text{SReLU}(x) = \max(x,0)^2 + + See `Primer: Searching for Efficient Transformers for Language Modeling`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.srelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dsrelu(*args, **kwargs) + + +class SReGLU(_ActivationOperation): + r"""Squared Rectified Gated Linear Unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{SReGLU}(a,b) = \max(a,0)^2 * b + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.sreglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dsreglu(*args, **kwargs) + + +class SiLU(_ActivationOperation): + r"""Sigmoid Linear Unit + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.silu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dsilu(*args, **kwargs) + + class SwiGLU(_ActivationOperation): r"""Swish gated linear unit diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_extra_input.py similarity index 68% rename from transformer_engine/pytorch/ops/basic/add_in_place.py rename to transformer_engine/pytorch/ops/basic/add_extra_input.py index e1493d3c7..1fcfa0466 100644 --- a/transformer_engine/pytorch/ops/basic/add_in_place.py +++ b/transformer_engine/pytorch/ops/basic/add_extra_input.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Fusible operation for in-place add.""" +"""Fusible operation for adding extra input tensor.""" from __future__ import annotations from collections.abc import Iterable @@ -18,16 +18,17 @@ from transformer_engine.pytorch.tensor import Quantizer -class AddInPlace(BasicOperation): - """Add in-place +class AddExtraInput(BasicOperation): + """Add extra input tensor This operation requires an extra tensor input to the operation - fuser. The main input is added in-place to the extra input, and a - view of the extra input is output. + fuser. It returns the sum of the main input and the extra input. + If in_place=True, the main input is added in-place to the extra + input, and a view of the extra input is output. - This operation is considered an advanced feature and most users - are discouraged from using it. In-place operations break some - autograd assumptions and they can result in subtle, esoteric bugs. + Using this operation with in_place=True is considered an advanced + feature and most users are discouraged from using it. In-place operations + break some autograd assumptions and they can result in subtle, esoteric bugs. Compare to `MakeExtraOutput`, which does a similar operation in the backward pass.
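For cross-checking the docstrings above, reference implementations of the newly documented activations in plain PyTorch, following the stated convention that the gating function acts on the first half of the input; these mirror the math only, not the fused tex kernels.

import torch
import torch.nn.functional as F

def srelu(x):    # squared ReLU: max(x, 0)^2
    return F.relu(x) ** 2

def silu(x):     # x * sigmoid(x)
    return F.silu(x)

def sreglu(x):   # gated squared ReLU; the gate acts on the FIRST half (TE convention)
    a, b = x.chunk(2, dim=-1)
    return F.relu(a) ** 2 * b

def swiglu(x):
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b

x = torch.randn(2, 8)
assert sreglu(x).shape == (2, 4)
assert torch.allclose(srelu(x), torch.clamp(x, min=0) ** 2)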
@@ -37,6 +38,10 @@ class AddInPlace(BasicOperation): # Operation expects buffer for output tensor num_extra_inputs: int = 1 + def __init__(self, *, in_place: bool = False): + super().__init__() + self._in_place = in_place + def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( "{self.__class__.__name__} operation has " @@ -59,12 +64,17 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: - output = basic_op_extra_inputs[0][0].detach() - output += input_ + extra_input = basic_op_extra_inputs[0][0] + if self._in_place: + extra_input = extra_input.detach() + extra_input += input_ + output = extra_input + else: + output = extra_input + input_ return output, [()] def fuser_backward( diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py index 0df165a06..bcd3c1417 100644 --- a/transformer_engine/pytorch/ops/basic/all_gather.py +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -40,7 +40,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: out: torch.Tensor diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index af928dd24..d8c1eb006 100644 --- a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -42,7 +42,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 59fc09607..70c70c54d 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,28 +12,32 @@ import torch -from transformer_engine.pytorch.module.base import get_workspace from ...cpp_extensions import general_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, reduce_scatter_along_first_dim, ) from ...fp8 import FP8GlobalStateManager, Recipe -from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD +from ...module.base import ( + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, + get_dummy_wgrad, + get_workspace, +) from ...tensor import Quantizer -from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ...tensor.mxfp8_tensor import MXFP8Quantizer +from ...tensor.float8_tensor import Float8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase -from ..op import BasicOperation, OperationContext -from .._common import maybe_dequantize, is_quantized_tensor from ...utils import ( canonicalize_device, canonicalize_dtype, clear_tensor_data, devices_match, ) +from ..op import BasicOperation, OperationContext +from .._common import 
maybe_dequantize, is_quantized_tensor def _wait_async(handle: Optional[Any]) -> None: @@ -75,7 +79,8 @@ class BasicLinear(BasicOperation): weight's `main_grad` attribute instead of relying on PyTorch autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be - meaningful. + meaningful. This is primarily intented to integrate with + Megatron-LM. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -291,10 +296,19 @@ def reset_parameters(self) -> None: # Quantize if needed if self._with_quantized_weight: quantizer = self.get_quantizer("forward", 1) + if quantizer is None: + raise RuntimeError( + "Tried to quantize weight with deferred initialization " + "due to meta device, but no quantizer was available. " + "This is most likely because the weight was initialized " + "within fp8_model_init, but the forward pass was not " + "performed within fp8_autocast." + ) quantizer.set_usage( rowwise=True, columnwise=torch.is_grad_enabled(), ) + quantizer.internal = False with torch.no_grad(): weight = quantizer(weight) @@ -303,72 +317,52 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_first_forward( - self, - *, - recipe: Optional[Recipe], - ) -> None: - super().pre_first_forward(recipe=recipe) - - # Initialize weights if needed - weight = self.weight - if weight.device.type == "meta": + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() + if self.weight.device.type == "meta": self.reset_parameters() - weight = self.weight - # Configure quantizers - if recipe is not None: - input_quantizer = self.get_quantizer("forward", 0) - weight_quantizer = self.get_quantizer("forward", 1) - grad_output_quantizer = self.get_quantizer("backward", 0) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_state(recipe=recipe) - # Specify required tensor formats + # Input/grad output quantizers use internal tensors + input_quantizer = self.get_quantizer("forward", 0) + grad_output_quantizer = self.get_quantizer("backward", 0) + if input_quantizer is not None: input_quantizer.internal = True - weight_quantizer.internal = True + if grad_output_quantizer is not None: grad_output_quantizer.internal = True - # Recipe-specific configuration - if recipe.float8_current_scaling(): - if any( - not isinstance(q, Float8CurrentScalingQuantizer) - for q in (input_quantizer, weight_quantizer, grad_output_quantizer) - ): - raise RuntimeError( - "FP8 current-scaling recipe is enabled, " - f"but input quantizer is {input_quantizer.__class__.__name__}, " - f"weight quantizer is {weight_quantizer.__class__.__name__}, " - f"grad output quantizer is {grad_output_quantizer.__class__.__name__}" - ) - input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if 
self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group - - # Make sure weight tensor has correct quantizer - # Note: Quantizer might have changed if quantization - # recipe changed - if isinstance( - weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ) and isinstance(weight, Float8TensorBase): - weight._quantizer = weight_quantizer + # Handle weight quantizer + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + weight_quantizer = self.get_quantizer("forward", 1) + if weight_quantizer is None: + pass + elif is_quantized_tensor(getattr(self, "weight", None)): + # Make sure weight param has correct quantizer + weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + weight_quantizer.internal = False + self.weight.update_quantizer(weight_quantizer.copy()) + else: + # Use internal tensors if quantized weights will not be + # exposed externally + weight_quantizer.internal = ( + not FP8GlobalStateManager.with_fp8_parameters() + and not getattr(self, "_with_quantized_weight", False) + ) @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin weight: torch.Tensor, *, + alpha: float = 1.0, bias: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, out: Optional[torch.Tensor] = None, + beta: Optional[float] = None, accumulate_into_out: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, @@ -388,6 +382,8 @@ def _functional_forward( Input tensor weight: torch.Tensor Weight tensor + alpha: float, default = 1.0 + Scaling factor applied to the result of the GEMM bias: torch.Tensor, optional Bias tensor device: torch.device, default = default CUDA device @@ -396,6 +392,8 @@ def _functional_forward( Tensor datatype out: torch.Tensor, optional Output tensor + beta: float, optional + Scaling factor applied to original value of out when accumulating into it accumulate_into_out: bool, default = `False` Add result to output tensor instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` @@ -441,7 +439,7 @@ def _functional_forward( if dtype is None: if out is not None and isinstance(out, torch.Tensor): dtype = out.dtype - elif weight is not None and isinstance(out, torch.Tensor): + elif weight is not None and isinstance(weight, torch.Tensor): dtype = weight.dtype else: raise ValueError( @@ -516,18 +514,11 @@ def _functional_forward( raise ValueError("Output tensor is quantized, but quantizer was not provided") else: output_quantizer = None - if isinstance(output_quantizer, MXFP8Quantizer): - raise RuntimeError( - "Attempting to generate MXFP8 output tensor, " - "but GEMM with MXFP8 output is not supported" - ) - if isinstance(output_quantizer, Float8BlockQuantizer): - raise RuntimeError( - "Attempting to generate Float8BlockQuantized output tensor, " - "but GEMM with Float8BlockQuantized output is not supported" - ) - if output_quantizer is not None: + if not isinstance(output_quantizer, Float8Quantizer): + raise RuntimeError( + "Attempting to generate quantized output tensor with unsupported quantizer" + ) output_quantizer.set_usage(rowwise=True, columnwise=False) # Check if accumulating into output tensor @@ -552,6 +543,8 @@ def 
_functional_forward( get_workspace(), out_dtype=dtype, quantization_params=output_quantizer, + alpha=alpha, + beta=beta, accumulate=accumulate_into_out, out=y, bias=bias, @@ -589,13 +582,17 @@ def _functional_backward( input: Optional[torch.Tensor], # pylint: disable=redefined-builtin weight: Optional[torch.Tensor], *, + grad_input_alpha: Optional[float] = None, input_requires_grad: bool = True, + grad_weight_alpha: Optional[float] = None, weight_requires_grad: bool = True, device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, grad_weight: Optional[torch.Tensor] = None, + grad_weight_beta: Optional[float] = None, accumulate_into_grad_weight: bool = False, grad_input: Optional[torch.Tensor] = None, + grad_input_beta: Optional[float] = None, accumulate_into_grad_input: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, @@ -618,8 +615,12 @@ def _functional_backward( weight: torch.Tensor, optional Weight tensor. Required to compute loss gradient w.r.t. input. + grad_input_alpha: float, optional + Scaling factor applied to the result of the dgrad GEMM input_requires_grad: bool Whether to compute loss gradient w.r.t. input tensor + grad_weight_alpha: float, optional + Scaling factor applied to the result of the wgrad GEMM weight_requires_grad: bool Whether to compute loss gradient w.r.t. weight tensor device: torch.device, default = default CUDA device @@ -628,10 +629,14 @@ def _functional_backward( Tensor datatype grad_weight: torch.Tensor, optional Loss gradient w.r.t. weight tensor + grad_weight_beta: float, optional + Scaling factor applied to original value of grad_weight when accumulating into it accumulate_into_grad_weight: bool, default = `False` Add result to weight grad instead of overwriting grad_input: torch.Tensor, optional Loss gradient w.r.t. 
input tensor + grad_input_beta: float, optional + Scaling factor applied to original value of grad_input when accumulating into it accumulate_into_grad_input: bool, default = `False` Add result to input grad instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` @@ -801,11 +806,12 @@ def _functional_backward( ) else: grad_input_quantizer = None - if isinstance(grad_input_quantizer, MXFP8Quantizer): - raise RuntimeError( - "Attempting to generate MXFP8 grad input tensor, " - "but GEMM with MXFP8 output is not supported" - ) + if grad_input_quantizer is not None: + if not isinstance(grad_input_quantizer, Float8Quantizer): + raise RuntimeError( + "Attempting to generate quantized grad input tensor " + "with unsupported quantizer" + ) # Check if accumulating into grad input tensor if accumulate_into_grad_input: @@ -827,6 +833,8 @@ def _functional_backward( get_workspace(), out_dtype=dtype, quantization_params=grad_input_quantizer, + alpha=grad_input_alpha, + beta=grad_input_beta, accumulate=accumulate_into_grad_input, layout="NN", out=dx, @@ -877,6 +885,8 @@ def _functional_backward( dy, get_workspace(), out_dtype=dw_dtype, + alpha=grad_weight_alpha, + beta=grad_weight_beta, accumulate=accumulate_into_grad_weight, layout="NT", out=dw, @@ -894,7 +904,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: @@ -903,27 +913,34 @@ def op_forward( weight_requires_grad = ctx.requires_grad and self.weight.requires_grad # FP8 metadata + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + output_quantizer = next_op_input_quantizer + grad_output_quantizer = self.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None - output_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None if with_quantized_compute: - - # Get quantizers - input_quantizer = self.get_quantizer("forward", 0) - weight_quantizer = self.get_quantizer("forward", 1) - output_quantizer = next_op_input_quantizer - grad_output_quantizer = self.get_quantizer("backward", 0) - grad_input_quantizer = prev_op_grad_input_quantizer - # Configure quantizers # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
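# Editor's note: minimal pure-PyTorch sketch of the alpha/beta semantics documented in the
# new GEMM arguments above: alpha scales the GEMM result, beta scales the existing value of
# `out` when accumulating into it. This only illustrates the intended math; it is not the
# general_gemm code path, and the tensor names are hypothetical.
import torch

x = torch.randn(4, 8)        # input
w = torch.randn(16, 8)       # weight, shaped (out_features, in_features)
out = torch.randn(4, 16)     # pre-existing output buffer
alpha, beta = 0.5, 1.0

# Accumulating GEMM: out <- alpha * (x @ w.T) + beta * out
expected = alpha * (x @ w.t()) + beta * out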
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) weight_quantizer.set_usage(rowwise=True, columnwise=False) + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + # Get autocast dtype if needed if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") @@ -947,15 +964,18 @@ def op_forward( ) # Save state for backward pass - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute - ctx.input_quantizer = input_quantizer - ctx.weight_quantizer = weight_quantizer - ctx.grad_output_quantizer = grad_output_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.dtype = dtype - ctx.input_requires_grad = input_requires_grad - ctx.weight_requires_grad = weight_requires_grad + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) + ctx.save_for_backward(x_local, w) + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizer = input_quantizer + ctx.weight_quantizer = weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer + ctx.dtype = dtype + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad return output @@ -968,20 +988,22 @@ def op_backward( # Saved tensors from forward pass (x_local, w) = ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. accumulate_into_main_grad = self._accumulate_into_main_grad grad_weight = None if ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(self.weight, "__fsdp_param__"): - self.weight.main_grad = self.weight.get_main_grad() - - if not hasattr(self.weight, "main_grad"): + weight_param = self.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = self.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -1008,6 +1030,17 @@ def op_backward( # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
if accumulate_into_main_grad: grad_weight = None + weight_param = self.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [grad_weight] diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index a985601e2..5ec0d2ce5 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,15 +10,8 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from ...utils import ( - canonicalize_device, - canonicalize_dtype, -) -from ...fp8 import FP8GlobalStateManager +from ..op import BasicOperation, OperationContext +from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -114,8 +107,8 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(bias) self.bias = bias - def pre_first_forward(self, *args, **kwargs) -> None: - super().pre_first_forward(*args, **kwargs) + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() if self.bias.device.type == "meta": self.reset_parameters() @@ -123,24 +116,14 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: x = input_ b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) - # Check if backward pass is needed - requires_grad = ctx.requires_grad - - # Check if previous op quantizes its output's gradient - grad_input_quantizer = None - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - grad_input_quantizer = prev_op_grad_input_quantizer - - if requires_grad: - ctx.with_quantized_compute = with_quantized_compute - ctx.grad_input_quantizer = grad_input_quantizer + if ctx.requires_grad: + ctx.grad_input_quantizer = prev_op_grad_output_quantizer return x + b @@ -152,10 +135,10 @@ def op_backward( dy = grad_output if dy.dim() > 1: quantizer = ctx.grad_input_quantizer - if ctx.with_quantized_compute and quantizer is not None: - db, dy = tex.bgrad_quantize(dy, quantizer) - else: + if quantizer is None: db = dy.sum(tuple(range(dy.dim() - 1))) + else: + db, dy = tex.bgrad_quantize(dy, quantizer) else: db = dy return dy, (db,) diff --git a/transformer_engine/pytorch/ops/basic/constant_scale.py b/transformer_engine/pytorch/ops/basic/constant_scale.py new file mode 100644 index 000000000..4de70c0e9 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/constant_scale.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
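# Editor's note: rough sketch of the Megatron-LM-style main_grad contract assumed by the
# wgrad-fusion path above: the framework attaches a `main_grad` buffer (and optionally
# `grad_added_to_main_grad` / `zero_out_wgrad` flags) to the weight, the op accumulates the
# wgrad GEMM result directly into that buffer, and autograd receives a dummy wgrad tensor
# instead. Attribute names match the code above; everything else is illustrative.
import torch

weight = torch.nn.Parameter(torch.randn(16, 8))
weight.main_grad = torch.zeros_like(weight)      # typically an FP32 buffer in practice
weight.grad_added_to_main_grad = False

# What the backward path does, schematically:
wgrad = torch.randn_like(weight)                 # stand-in for the wgrad GEMM result
weight.main_grad += wgrad                        # accumulate_into_main_grad=True
weight.grad_added_to_main_grad = True            # main_grad already holds this step's wgrad
dummy_wgrad = torch.empty_like(weight)           # returned so autograd stores no real grad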
+ +"""Fusible operation for constant scaling.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from ...tensor import Quantizer + + +class ConstantScale(BasicOperation): + """Multiply by a constant""" + + def __init__(self, scale: float) -> None: + super().__init__() + self.scale = scale + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + return input_ * self.scale + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + return grad_output * self.scale, () diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py new file mode 100644 index 000000000..30ccf5ebc --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for dropout.""" + +from __future__ import annotations +from typing import Optional + +import torch +import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...tensor import Quantizer +from ...tensor._internal.float8_tensor_base import Float8TensorBase +from .._common import maybe_autocast_dtype, maybe_dequantize +from ..op import BasicOperation, OperationContext + + +class Dropout(BasicOperation): + """Randomly zero out tensor entries during training + + During training, tensor entries are randomly set to zero with + probability :math:`p` and remaining entries are scaled by + :math:`1/(1-p)`. 
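# Editor's note: quick numeric illustration of the scaling described in the docstring above:
# surviving entries are multiplied by 1/(1-p), so the expected value of each entry is
# preserved. A standalone sketch, not taken from the test suite. The ConstantScale op
# defined just above is simpler still: it multiplies both the forward output and the
# incoming gradient by the same constant.
import torch

p = 0.25
x = torch.ones(1024, 1024)
keep = torch.empty_like(x).bernoulli_(1 - p)   # 1 with probability 1-p, else 0
y = x * keep / (1 - p)
print(y.mean())                                # stays close to 1.0 in expectation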
+ + """ + + def __init__(self, p: float) -> None: + super().__init__() + self.dropout_probability: float = p + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Output dtype + dtype = maybe_autocast_dtype(default_dtype=input_.dtype) + + # Choose implementation + impl = None + if not self.training: + impl = "evaluation" + elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16): + impl = "fused" + else: + impl = "unfused" + + # Perform dropout + out: torch.Tensor + mask: Optional[torch.Tensor] = None + if impl == "evaluation": + out = input_ + elif impl == "fused": + x = input_ + if not isinstance(x, Float8TensorBase): + x = maybe_dequantize(x, dtype=dtype) + out, mask = tex.dropout_fwd(x, self.dropout_probability) + elif impl == "unfused": + x = maybe_dequantize(input_, dtype=dtype) + keep_prob = 1 - self.dropout_probability + mask = torch.empty_like(x) + mask.bernoulli_(keep_prob) + mask *= 1 / keep_prob + out = x * mask + else: + raise ValueError(f"Unsupported forward implementation {impl}") + + # Save context for backward + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(mask) + ctx.save_for_backward(mask) + ctx.impl = impl + ctx.dropout_probability = self.dropout_probability + ctx.dtype = dtype + + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (mask,) = ctx.saved_tensors + + # Perform dropout backward pass + grad_input: torch.Tensor + if ctx.impl == "evaluation": + grad_input = grad_output + elif ctx.impl == "fused": + dy = maybe_dequantize(grad_output, dtype=ctx.dtype) + grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability) + elif ctx.impl == "unfused": + dy = maybe_dequantize(grad_output, dtype=ctx.dtype) + grad_input = dy * mask + else: + raise ValueError(f"Unsupported backward implementation {ctx.impl}") + + return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/identity.py b/transformer_engine/pytorch/ops/basic/identity.py index 3161e77c7..788b3aac8 100644 --- a/transformer_engine/pytorch/ops/basic/identity.py +++ b/transformer_engine/pytorch/ops/basic/identity.py @@ -23,7 +23,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: return input_ diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index d8196c1bd..440fee34d 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -6,12 +6,12 @@ from __future__ import annotations from typing import Optional +import os import torch -from ...utils import clear_tensor_data -from .._common import maybe_dequantize -from ..op import BasicOperation, OperationContext +from ... 
import torch_version +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -20,6 +20,9 @@ warmup_jit_l2normalization_all_dtypes, ) from ...tensor import Quantizer +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize class L2Normalization(BasicOperation): @@ -60,7 +63,11 @@ def __init__( # JIT warmup for L2Normalization fused operations if seq_length and micro_batch_size: - if torch.cuda.is_available(): + if ( + torch.cuda.is_available() + and torch_version() >= (2, 0, 0) + and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))) + ): set_jit_fusion_options() # For L2Normalization, we don't know the hidden size until forward pass, # but we can warm up with common sizes. For QK normalization, this will be @@ -74,7 +81,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: # Use input directly - torch.compile can handle multi-dimensional tensors @@ -86,7 +93,7 @@ def op_forward( # Compute L2 normalization using fused implementation # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps) if requires_grad: - # Training: use version that returns both output and intermediate values + # Training: use version that returns output and intermediate values for backward pass y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps) else: # Inference: use lightweight version that only returns output @@ -95,6 +102,8 @@ def op_forward( # Save state for backward pass if requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y @@ -110,7 +119,7 @@ def op_backward( dy = maybe_dequantize(grad_output) - # Compute L2 norm backward pass using fused implementation + # Compute L2 norm backward pass using fused implementation - recalculates l2_norm_squared_eps dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps) # Clear saved tensors if possible diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 26c39909e..d429c4fa4 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -18,8 +18,10 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton -from ...fp8 import FP8GlobalStateManager from ...constants import TE_DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -28,8 +30,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class LayerNorm(BasicOperation): @@ -173,8 +173,8 @@ def reset_parameters(self) -> None: self.weight = weight self.bias = bias - def pre_first_forward(self, *args, **kwargs) -> None: - super().pre_first_forward(*args, **kwargs) + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() if self.weight.device.type == "meta" or self.bias.device.type == "meta": 
self.reset_parameters() @@ -182,7 +182,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: if is_in_onnx_export_mode(): @@ -205,17 +205,8 @@ def op_forward( w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) b = maybe_dequantize(self.bias, dtype).view((inner_dim,)) - # Check if backward pass is needed - requires_grad = ctx.requires_grad - - # Check if output is quantized - output_quantizer = None - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - output_quantizer = next_op_input_quantizer - # Compute layer norm - sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"] use_layernorm_triton = bool( int(os.environ.get('NVTE_USE_LAYERNORM_TRITON', '0')) ) and IS_HIP_EXTENSION layernorm_fwd_func = te_layernorm_fwd_triton if use_layernorm_triton else layernorm_fwd y, means, rstdevs = layernorm_fwd_func( @@ -224,14 +215,16 @@ def op_forward( b, self.eps, None, - output_quantizer, + next_op_input_quantizer, TE_DType[dtype], sm_margin, self.zero_centered_gamma, ) # Save state for backward pass - if requires_grad: + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index 81b581ae2..34228affc 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation): If this operation is included in the operation fuser, then the operation fuser will return the intermediate tensor as an extra - tensor output. In the backward pass, the gradient is directly - accumulated into the gradient w.r.t. the extra output. + tensor output. - This operation is considered an advanced feature and most users - are discouraged from using it. In-place operations break some - autograd assumptions and they can result in subtle, esoteric bugs. + In the backward pass, the gradient may be directly + accumulated into the gradient w.r.t. the extra output. This is + controlled by the in_place kwarg. Currently, the BackwardLinearAdd + fusion can only be applied when in_place=True. + + Using this operation with in_place=True is + considered an advanced feature. Most users are discouraged + from enabling in-place gradient accumulation, as in-place + operations break some autograd assumptions and they can result + in subtle, esoteric bugs. + + Compare to `AddExtraInput`, which does a similar operation in the + backward pass. 
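# Editor's note: the backward behaviour controlled by the new in_place kwarg, reduced to
# plain tensor ops (compare with fuser_backward below). With in_place=True the incoming
# gradient is accumulated into the extra output's gradient buffer, which is what the
# BackwardLinearAdd fusion relies on; with in_place=False a fresh tensor is produced, which
# is what the new BackwardAddRMSNorm fusion expects. Tensors here are hypothetical.
import torch

grad_extra_output = torch.randn(4, 8)
grad_output = torch.randn(4, 8)

# in_place=False: out-of-place add
grad_input = grad_extra_output + grad_output

# in_place=True: reuse the existing buffer
grad_extra_output += grad_output
grad_input = grad_extra_output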
""" @@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation): # Operation expects buffer for output tensor num_extra_outputs: int = 1 + def __init__(self, *, in_place: bool = False): + super().__init__() + self._in_place: bool = in_place + def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( "{self.__class__.__name__} operation has " @@ -59,7 +69,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -76,6 +86,10 @@ def fuser_backward( Iterable[Iterable[Optional[torch.Tensor]]], Iterable[Iterable[Optional[torch.Tensor]]], ]: - grad_input = basic_op_grad_extra_outputs[0][0] - grad_input += grad_output + grad_extra_output = basic_op_grad_extra_outputs[0][0] + if self._in_place: + grad_extra_output += grad_output + grad_input = grad_extra_output + else: + grad_input = grad_extra_output + grad_output return grad_input, [()], [()] diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 005e9fd8d..dcfc3c4f7 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -50,7 +50,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: @@ -64,7 +64,8 @@ def op_forward( if quantize_forward and not is_quantized_tensor(out): out = self.get_quantizer("forward", 0)(out) - ctx.quantize_backward = quantize_backward + if ctx.requires_grad: + ctx.quantize_backward = quantize_backward return out def op_backward( diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index 1238b0879..e0017853f 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -40,7 +40,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index 8d8b75ff0..50af9fcff 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -38,10 +38,11 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: - ctx.input_shape = input_.size() + if ctx.requires_grad: + ctx.input_shape = input_.size() return input_.reshape(*self._shape) def op_backward( diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 250769ec8..bbe805fe9 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -21,8 +21,10 @@ te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton ) -from ...fp8 import FP8GlobalStateManager from ...constants import TE_DType +from ...cpu_offload 
import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -31,8 +33,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class RMSNorm(BasicOperation): @@ -160,8 +160,8 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_first_forward(self, *args, **kwargs) -> None: - super().pre_first_forward(*args, **kwargs) + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() if self.weight.device.type == "meta": self.reset_parameters() @@ -169,7 +169,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: if is_in_onnx_export_mode(): @@ -191,17 +191,8 @@ def op_forward( x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim)) w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) - # Check if backward pass is needed - requires_grad = ctx.requires_grad - - # Check if output is quantized - output_quantizer = None - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - output_quantizer = next_op_input_quantizer - # Compute RMSNorm - sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"] # Compute RMSNorm forward pass rmsnorm_fwd_func = te_rmsnorm_fwd_triton if self.use_rmsnorm_triton else rmsnorm_fwd y, _, rstdevs = rmsnorm_fwd_func( @@ -209,14 +200,16 @@ def op_forward( w, self.eps, None, - output_quantizer, + next_op_input_quantizer, TE_DType[dtype], sm_margin, self.zero_centered_gamma, ) # Save state for backward pass - if requires_grad: + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 29d3c50cd..d14454dc0 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -6,14 +6,22 @@ """Compound tensor operation supported by the operation fuser.""" -from .backward_bias_activation import ( - BackwardBiasActivation, - fuse_backward_bias_activation, +from .backward_activation_bias import ( + BackwardActivationBias, + fuse_backward_activation_bias, +) +from .backward_add_rmsnorm import ( + BackwardAddRMSNorm, + fuse_backward_add_rmsnorm, ) from .backward_linear_add import ( BackwardLinearAdd, fuse_backward_linear_add, ) +from .backward_linear_scale import ( + BackwardLinearScale, + fuse_backward_linear_scale, +) from .forward_linear_bias_activation import ( ForwardLinearBiasActivation, fuse_forward_linear_bias_activation, @@ -22,6 +30,10 @@ ForwardLinearBiasAdd, fuse_forward_linear_bias_add, ) +from .forward_linear_scale_add import ( + ForwardLinearScaleAdd, + fuse_forward_linear_scale_add, +) from torch.utils.cpp_extension import IS_HIP_EXTENSION if not IS_HIP_EXTENSION: from .userbuffers_backward_linear import ( diff --git a/transformer_engine/pytorch/ops/fused/backward_bias_activation.py 
b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py similarity index 86% rename from transformer_engine/pytorch/ops/fused/backward_bias_activation.py rename to transformer_engine/pytorch/ops/fused/backward_activation_bias.py index f4b7b9ec3..40510c856 100644 --- a/transformer_engine/pytorch/ops/fused/backward_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Fused backward dbias + dact + quantize.""" +"""Fused backward dact + dbias + quantize.""" from __future__ import annotations from typing import Optional @@ -29,8 +29,8 @@ _fusible_activations = tuple(_fused_activations.keys()) -class BackwardBiasActivation(FusedOperation): - """Fused backward dbias + dact + quantize +class BackwardActivationBias(FusedOperation): + """Fused backward dact + dbias + quantize Uses the next operation's input quantizer. @@ -66,15 +66,10 @@ def fuser_backward( dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype) # Get previous op quantizer - if not bias_op_ctx.with_quantized_compute: - raise RuntimeError( - "BackwardBiasActivation requires quantized compute, " - "but Bias context has it disabled" - ) quantizer = bias_op_ctx.grad_input_quantizer if quantizer is None: raise RuntimeError( - "BackwardBiasActivation requires previous op's grad output quantizer, " + "BackwardActivationBias requires previous op's grad output quantizer, " "but Bias context has no quantizer" ) @@ -87,11 +82,11 @@ def fuser_backward( return dx, [(), (db,)], [(), ()] -def fuse_backward_bias_activation( +def fuse_backward_activation_bias( ops: list[tuple[FusibleOperation, list[int]]], recipe: Optional[Recipe], ) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dbias + dact + quantize + """Fused backward dact + dbias + quantize Parameters ---------- @@ -109,7 +104,7 @@ def fuse_backward_bias_activation( """ # Check if recipe supports bias activation fusion - if recipe is None or not (recipe.delayed() or recipe.mxfp8()): + if recipe is None: return ops # Scan through ops, fusing if possible @@ -138,7 +133,7 @@ def fuse_backward_bias_activation( ops = ops[1:] # Replace window with fused op - op = BackwardBiasActivation( + op = BackwardActivationBias( activation=window[0][0], bias=window[1][0], ) diff --git a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py new file mode 100644 index 000000000..54a23395a --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py @@ -0,0 +1,133 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
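# Editor's note: sketch of the op pattern that the fused backward defined in this new file
# targets. In forward order, MakeExtraOutput (with the default in_place=False) exposes the
# residual stream as an extra output and RMSNorm then normalizes it; in the backward pass
# the residual gradient is folded into the RMSNorm dgrad kernel (tex.rmsnorm_bwd_add below).
# The Sequential usage, the RMSNorm constructor argument, and the output ordering are
# assumptions, not copied from the tests.
import torch
import transformer_engine.pytorch as te

hidden_size = 1024
model = te.ops.Sequential(
    te.ops.MakeExtraOutput(),      # in_place defaults to False
    te.ops.RMSNorm(hidden_size),   # assumed to take the normalized size
)
x = torch.randn(32, hidden_size, device="cuda", requires_grad=True)
y, residual = model(x)             # the fuser returns the intermediate tensor as an extra output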
+ +"""Fused backward RMSNorm + add.""" + +from __future__ import annotations +from typing import Optional +import math + +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.ops.basic import MakeExtraOutput, RMSNorm + +from transformer_engine.pytorch.ops.op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) +from ...utils import clear_tensor_data +from .._common import maybe_dequantize + + +class BackwardAddRMSNorm(FusedOperation): + """Fused backward RMSNorm + add""" + + def __init__(self, *, add: MakeExtraOutput, rmsnorm: RMSNorm): + super().__init__((add, rmsnorm)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + rmsnorm_op = self.basic_ops[1] + rmsnorm_op_ctx = basic_op_ctxs[0] + + # Saved tensors from forward pass + x, rstdevs = rmsnorm_op_ctx.saved_tensors + + # Tensor dims + weight_dims = rmsnorm_op.weight.size() + inner_dim = math.prod(weight_dims) + + # Check input tensors + dtype = rmsnorm_op_ctx.dtype + extra_grad = basic_op_grad_extra_outputs[1][0] + dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) + w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,)) + add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size()) + + # Compute RMSNorm backward pass + dx, dw = tex.rmsnorm_bwd_add( + dy, + x, + add, + rstdevs, + w, + rmsnorm_op._sm_margins["backward"], + rmsnorm_op.zero_centered_gamma, + ) + + # Clear saved tensors if possible + clear_tensor_data(x) + clear_tensor_data(rstdevs) + + # Reshape results + grad_input = dx.view(grad_output.size()) + grad_weight = dw.view(weight_dims) + + return grad_input, [(grad_weight,), ()], [(), ()] + + +def fuse_backward_add_rmsnorm( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fused backward RMSNorm + add + + Parameters + ---------- + ops: list of tuples + Backward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is RMSNorm + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, RMSNorm): + continue + + # Check if second op is "make extra output" + op, _ = ops[0] + if not isinstance(op, MakeExtraOutput): + continue + if op._in_place: + continue + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = BackwardAddRMSNorm( + rmsnorm=window[0][0], + add=window[1][0], + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 54ddfaa5c..845ba262a 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -9,13 +9,10 @@ import torch -from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...module.base import get_dummy_wgrad from ...utils import clear_tensor_data +from ..basic import BasicLinear, MakeExtraOutput +from ..op import FusedOperation, FusibleOperation, OperationContext class BackwardLinearAdd(FusedOperation): @@ -29,10 +26,10 @@ class BackwardLinearAdd(FusedOperation): def __init__( self, *, - linear: BasicLinear, backward_add: MakeExtraOutput, + linear: BasicLinear, ) -> None: - super().__init__((linear, backward_add)) + super().__init__((backward_add, linear)) def fuser_backward( self, @@ -47,26 +44,28 @@ def fuser_backward( ]: # Get basic operations - linear_op = self.basic_ops[0] + linear_op = self.basic_ops[1] linear_op_ctx = basic_op_ctxs[0] # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -92,12 +91,23 @@ def fuser_backward( grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) - if accumulate_into_main_grad: - grad_weight = None # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
+ if accumulate_into_main_grad: + grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [(grad_weight,), ()], [(), ()] @@ -139,6 +149,8 @@ def fuse_backward_linear_add( op, _ = ops[0] if not isinstance(op, MakeExtraOutput): continue + if not op._in_place: + continue window.extend(ops[:1]) ops = ops[1:] diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py new file mode 100644 index 000000000..a9595d516 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -0,0 +1,165 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused backward dgrad GEMM + scale.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from ...module.base import get_dummy_wgrad +from ...utils import clear_tensor_data +from ..basic import BasicLinear, ConstantScale +from ..op import FusedOperation, FusibleOperation, OperationContext + + +class BackwardLinearScale(FusedOperation): + """Fused backward dgrad GEMM + scale + + Column tensor parallelism is not supported since that requires + communication immediately after the dgrad GEMM. + + """ + + def __init__( + self, + *, + scale: ConstantScale, + linear: BasicLinear, + ) -> None: + super().__init__((linear, scale)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + linear_op = self.basic_ops[0] + linear_op_ctx = basic_op_ctxs[1] + scale_op = self.basic_ops[1] + + # Saved tensors from forward pass + (x_local, w) = linear_op_ctx.saved_tensors + + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. 
+ accumulate_into_main_grad = linear_op._accumulate_into_main_grad + grad_weight = None + if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = weight_param.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Linear backward pass + grad_input, grad_weight = BasicLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=w, + input_requires_grad=linear_op_ctx.input_requires_grad, + grad_input_alpha=scale_op.scale, + weight_requires_grad=linear_op_ctx.weight_requires_grad, + grad_weight_alpha=scale_op.scale, + dtype=linear_op_ctx.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_quantized_compute=linear_op_ctx.with_quantized_compute, + input_quantizer=linear_op_ctx.input_quantizer, + weight_quantizer=linear_op_ctx.weight_quantizer, + grad_output_quantizer=linear_op_ctx.grad_output_quantizer, + grad_input_quantizer=linear_op_ctx.grad_input_quantizer, + ) + + # Clear input tensor if possible + clear_tensor_data(x_local) + + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. + if accumulate_into_main_grad: + grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + return grad_input, [(), (grad_weight,)], [(), ()] + + +def fuse_backward_linear_scale( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fused backward dgrad GEMM + constant scale + + Parameters + ---------- + ops: list of tuples + Backward pass operations and the indices of the corresponding + basic operations. 
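# Editor's note: small pure-PyTorch check of the identity this fusion relies on: for
# y = c * (x @ W.T), both dL/dx and dL/dW are exactly c times the unscaled GEMM gradients,
# so the ConstantScale factor can be folded into the dgrad and wgrad GEMMs via the new
# grad_input_alpha / grad_weight_alpha arguments. Shapes are arbitrary.
import torch

c = 0.125
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
dy = torch.randn(4, 16)

y = c * (x @ w.t())
y.backward(dy)

assert torch.allclose(x.grad, c * (dy @ w), atol=1e-5)      # dgrad scaled by c
assert torch.allclose(w.grad, c * (dy.t() @ x), atol=1e-5)  # wgrad scaled by c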
+ + Returns + ------- + ops: list of tuples + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is constant scale + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, ConstantScale): + continue + + # Check if second op is linear + op, _ = ops[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication after the dgrad GEMM + continue + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = BackwardLinearScale( + scale=window[0][0], + linear=window[1][0], + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 5d1223bd8..02bcfee0a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import BasicLinear, Bias -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager from ...tensor import Quantizer +from ..basic import BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasActivation(FusedOperation): @@ -59,7 +56,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -89,18 +86,12 @@ def fuser_forward( weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad # FP8 metadata + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + output_quantizer = next_op_input_quantizer + grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None - output_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None - if with_quantized_compute: - input_quantizer = linear_op.get_quantizer("forward", 0) - weight_quantizer = linear_op.get_quantizer("forward", 1) - output_quantizer = next_op_input_quantizer - grad_output_quantizer = linear_op.get_quantizer("backward", 0) - grad_input_quantizer = prev_op_grad_input_quantizer # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -126,18 +117,20 @@ def fuser_forward( ) # Save state for backward pass - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute - linear_op_ctx.input_quantizer = input_quantizer - linear_op_ctx.weight_quantizer = weight_quantizer - linear_op_ctx.grad_output_quantizer = grad_output_quantizer - 
linear_op_ctx.grad_input_quantizer = grad_input_quantizer - linear_op_ctx.dtype = dtype - linear_op_ctx.input_requires_grad = input_requires_grad - linear_op_ctx.weight_requires_grad = weight_requires_grad - if bias_op is not None: - bias_op_ctx.with_quantized_compute = with_quantized_compute - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() + if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None and bias_op_ctx.requires_grad: + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 5055bc60a..15cc081c1 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) -from transformer_engine.pytorch.tensor import Quantizer +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer +from ..basic import AddExtraInput, BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasAdd(FusedOperation): @@ -33,7 +30,7 @@ def __init__( *, linear: BasicLinear, bias: Optional[Bias], - add: AddInPlace, + add: AddExtraInput, ) -> None: # Basic operations that comprise this fused operation @@ -57,7 +54,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -83,17 +80,12 @@ def fuser_forward( weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad # FP8 metadata - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None - if with_quantized_compute: - input_quantizer = linear_op.get_quantizer("forward", 0) - weight_quantizer = linear_op.get_quantizer("forward", 1) - grad_output_quantizer = linear_op.get_quantizer("backward", 0) - grad_input_quantizer = prev_op_grad_input_quantizer + grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer + with_quantized_compute = 
FP8GlobalStateManager.is_fp8_enabled() # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -122,18 +114,20 @@ def fuser_forward( ) # Save state for backward pass - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute - linear_op_ctx.input_quantizer = input_quantizer - linear_op_ctx.weight_quantizer = weight_quantizer - linear_op_ctx.grad_output_quantizer = grad_output_quantizer - linear_op_ctx.grad_input_quantizer = grad_input_quantizer - linear_op_ctx.dtype = dtype - linear_op_ctx.input_requires_grad = input_requires_grad - linear_op_ctx.weight_requires_grad = weight_requires_grad - if bias_op is not None: - bias_op_ctx.with_quantized_compute = with_quantized_compute - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() + if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None and bias_op_ctx.requires_grad: + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() return output, [() for _ in range(len(self.basic_ops))] @@ -184,8 +178,10 @@ def fuse_forward_linear_bias_add( continue op, _ = ops[0] - # Check if next op is add in-place - if not isinstance(op, AddInPlace): + # Check if next op is in-place add extra input + if not isinstance(op, AddExtraInput): + continue + if not op._in_place: continue add = op window.extend(ops[:1]) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py new file mode 100644 index 000000000..21190d4fc --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -0,0 +1,179 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for forward GEMM + scale + add.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer +from ..basic import AddExtraInput, BasicLinear, ConstantScale +from ..op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) + + +class ForwardLinearScaleAdd(FusedOperation): + """Fused forward GEMM + scale + add + + Row tensor parallelism is not supported since that requires + communication immediately after the GEMM. 
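# Editor's note: sketch of the forward op pattern this fusion matches (see
# fuse_forward_linear_scale_add further below): BasicLinear -> ConstantScale ->
# AddExtraInput(in_place=True), i.e. out = residual + scale * (x @ W.T), which the fused op
# computes as a single accumulating GEMM with alpha=scale. The Sequential usage and
# constructor signatures are assumptions based on the surrounding code, not verified calls.
import torch
import transformer_engine.pytorch as te

in_features, out_features = 1024, 1024
model = te.ops.Sequential(
    te.ops.BasicLinear(in_features, out_features),
    te.ops.ConstantScale(0.5),
    te.ops.AddExtraInput(in_place=True),
)
x = torch.randn(32, in_features, device="cuda")
residual = torch.randn(32, out_features, device="cuda")
y = model(x, residual)   # extra input passed as an additional positional argument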
+ + """ + + def __init__( + self, + *, + linear: BasicLinear, + scale: ConstantScale, + add: AddExtraInput, + ) -> None: + super().__init__((linear, scale, add)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + linear_op = self.basic_ops[0] + linear_op_ctx = basic_op_ctxs[0] + scale_op = self.basic_ops[1] + + # Check which grads are required + input_requires_grad = linear_op_ctx.requires_grad + weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad + + # FP8 metadata + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + output_quantizer = None + grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + + # Get extra input tensor for add operation + extra_input = basic_op_extra_inputs[2][0] + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = linear_op.weight.dtype + + # Linear forward + output, x_local, w = BasicLinear._functional_forward( + input=input_, + weight=linear_op.weight, + alpha=scale_op.scale, + dtype=dtype, + out=extra_input, + accumulate_into_out=True, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, + ) + + # Save state for backward pass + if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + + return output, [() for _ in range(len(self.basic_ops))] + + +def fuse_forward_linear_scale_add( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fuse forward GEMM + scale + add + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 3: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after the + # GEMM + continue + linear = op + op, _ = ops[0] + + # Check if next op is constant scale + if not isinstance(op, ConstantScale): + continue + scale = op + window.extend(ops[:1]) + ops = ops[1:] + op, _ = ops[0] + + # Check if next op is in-place add extra input + if not isinstance(op, AddExtraInput): + continue + if not op._in_place: + continue + add = op + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = ForwardLinearScaleAdd( + linear=linear, + scale=scale, + add=add, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 4fbc28482..1ecdba625 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -10,15 +10,16 @@ import torch -from transformer_engine_torch import CommOverlapType +from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm from ...cpp_extensions import general_gemm -from ...distributed import gather_along_first_dim, get_distributed_world_size +from ...distributed import get_distributed_world_size from ...module.base import ( + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, fill_userbuffers_buffer_for_all_gather, + get_dummy_wgrad, get_ub, get_workspace, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, ) from ...tensor.quantized_tensor import Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer @@ -48,14 +49,14 @@ def __init__( # Basic operations that comprise this fused operation op_idxs = {"linear": None, "bias": None, "reduce_scatter": None} ops = [] - if reduce_scatter is not None: - op_idxs["reduce_scatter"] = len(ops) - ops.append(reduce_scatter) + op_idxs["linear"] = len(ops) + ops.append(linear) if bias is not None: op_idxs["bias"] = len(ops) ops.append(bias) - op_idxs["linear"] = len(ops) - ops.append(linear) + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) # Initialize base class super().__init__(ops) @@ -240,16 +241,16 @@ def _functional_backward( with_dgrad_all_gather_x = False with_wgrad_reduce_scatter_dx = False if tensor_parallel_mode == "row": - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.AG with_dgrad_all_gather_dy = True elif tensor_parallel_mode == "column": if input_requires_grad and weight_requires_grad: with_bulk_overlap = True - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.AG with_dgrad_all_gather_x = True - ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad") + ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute) ub_type_wgrad = CommOverlapType.RS with_wgrad_reduce_scatter_dx = True if ub_comm_wgrad.is_fp8_ubuf(): @@ -257,7 
+258,7 @@ def _functional_backward( "Userbuffers reduce-scatter is not supported with FP8 buffers" ) else: - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.RS with_dgrad_reduce_scatter_dx = True if ub_comm_dgrad.is_fp8_ubuf(): @@ -398,26 +399,35 @@ def _functional_backward( # Initialize grad output if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream() + + ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute) + grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_comm_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - dy, dy_work = gather_along_first_dim( + + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. + with torch.cuda.stream(dgrad_send_stream): + dy, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, dy_local, + grad_output_quantizer, tensor_parallel_group, - async_op=True, - quantizer=grad_output_quantizer, ) - # Synchronize with the main stream - dy_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) if tensor_parallel_mode == "column": dy = dy_local @@ -495,7 +505,7 @@ def fuser_backward( # Get basic operations idx = self._op_idxs["linear"] linear_op = self.basic_ops[idx] - linear_op_ctx = basic_op_ctxs[idx] + linear_op_ctx = basic_op_ctxs[-1] bias_op = None if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] @@ -504,20 +514,22 @@ def fuser_backward( # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. 
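Editor's note: to make the Megatron-LM wgrad fusion described in the comment above concrete, the weight gradient is accumulated directly into the externally managed weight.main_grad buffer and autograd only receives a placeholder. A standalone sketch of the pattern (2-D tensors, attribute names follow the surrounding code, not the actual TE kernels):

import torch

def compute_wgrad(x, grad_output, weight):
    # x: (tokens, in_features), grad_output: (tokens, out_features)
    wgrad = grad_output.t() @ x                      # (out_features, in_features)
    if getattr(weight, "main_grad", None) is not None:
        weight.main_grad.add_(wgrad)                 # fused accumulation into the external buffer
        weight.grad_added_to_main_grad = True
        return torch.empty_like(weight)              # cheap dummy so autograd bookkeeping stays valid
    return wgrad                                     # ordinary path: return the real gradient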
accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -549,13 +561,25 @@ def fuser_backward( # Clear input tensor if possible clear_tensor_data(x_local) - # Return gradients - grad_params = [() for _ in range(len(self.basic_ops))] + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + # Return gradients + grad_params = [() for _ in range(len(self.basic_ops))] grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) + grad_params.reverse() grad_extra_inputs = [() for _ in range(len(self.basic_ops))] return grad_input, grad_params, grad_extra_inputs diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 30d9cdaae..a604e57dc 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,6 +12,7 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size from ...fp8 import FP8GlobalStateManager from ...module.base import ( @@ -182,14 +183,14 @@ def _functional_forward( if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") if output_quantizer is not None: - raise ValueError("FP8 output is not supported") + raise ValueError("Quantized output is not supported") else: input_quantizer = None weight_quantizer = None output_quantizer = None # Get Userbuffers communicator - ub_comm = get_ub(ub_comm_name + "_fprop") + ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute) with_ub_all_gather = tensor_parallel_mode == "column" with_ub_reduce_scatter = tensor_parallel_mode == "row" ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS @@ -282,7 +283,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -307,21 +308,17 @@ def fuser_forward( weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad # 
Quantization metadata + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): raise RuntimeError( f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})" ) - input_quantizer = linear_op.get_quantizer("forward", 0) - weight_quantizer = linear_op.get_quantizer("forward", 1) - grad_output_quantizer = linear_op.get_quantizer("backward", 0) - grad_input_quantizer = prev_op_grad_input_quantizer # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -356,19 +353,21 @@ def fuser_forward( w = extra_outputs["weight"] # Save state for backward pass - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute - linear_op_ctx.input_quantizer = input_quantizer - linear_op_ctx.weight_quantizer = weight_quantizer - linear_op_ctx.grad_output_quantizer = grad_output_quantizer - linear_op_ctx.grad_input_quantizer = grad_input_quantizer - linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() - linear_op_ctx.input_requires_grad = input_requires_grad - linear_op_ctx.weight_requires_grad = weight_requires_grad - if bias_op is not None: - bias_op_ctx.with_quantized_compute = with_quantized_compute - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() + if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None and bias_op_ctx.requires_grad: + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 7549cda71..df8843649 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -7,12 +7,13 @@ """Manager class for a pipeline of fusible operations.""" from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any, Optional +import itertools import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling from transformer_engine.pytorch.ops.op import ( BasicOperation, FusibleOperation, @@ -20,10 +21,13 @@ ) from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.ops.fused import ( - fuse_backward_bias_activation, + fuse_backward_activation_bias, + fuse_backward_add_rmsnorm, 
fuse_backward_linear_add, + fuse_backward_linear_scale, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, + fuse_forward_linear_scale_add, ) if not IS_HIP_EXTENSION: from transformer_engine.pytorch.ops.fused import ( @@ -74,8 +78,7 @@ def forward( input_: torch.Tensor, fuser: OperationFuser, basic_op_kwargs: list[dict[str, Any]], - is_grad_enabled: bool, - *params_and_extra_inputs: torch.nn.Parameter, + *params_and_extra_inputs: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -89,8 +92,6 @@ def forward( Container for the pipeline of operations to run basic_op_kwargs: list of dict Keyword arguments to BasicOperation - is_grad_enabled: bool - Should context be saved for backward *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. @@ -109,10 +110,10 @@ def forward( # Mark input tensors as not deletable in backward for tensor in (input_,) + params_and_extra_inputs: - tensor.do_not_clear = True + tensor._do_not_clear = True # Unflatten list of parameters and extra tensor inputs - extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] + extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :] basic_op_extra_inputs = [] for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) @@ -120,44 +121,37 @@ def forward( # Apply forward ops x = input_ - requires_grad = is_grad_enabled and x.requires_grad - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() extra_outputs = [None] * fuser._num_basic_ops for op, basic_op_idxs in fuser._forward_ops: - # Check if backward op is required - if is_grad_enabled: - if not requires_grad: - requires_grad = any(param.requires_grad for param in op.parameters()) - if not requires_grad: - requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) + # Set if backward op is required for idx in basic_op_idxs: - basic_op_ctxs[idx].requires_grad = requires_grad + basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None - prev_op_grad_input_quantizer = None - if prev_op is not None and with_quantized_compute: - prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer() + prev_op_grad_output_quantizer = None + if prev_op is not None: + prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None next_op_input_quantizer = None - if next_op is not None and with_quantized_compute: + if next_op is not None: next_op_input_quantizer = next_op.get_input_quantizer() x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], x, basic_op_extra_inputs=extra_inputs, - prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, + prev_op_grad_output_quantizer=prev_op_grad_output_quantizer, next_op_input_quantizer=next_op_input_quantizer, basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for y in ys: - y.requires_grad_(requires_grad) + y.requires_grad_(idx >= fuser.first_op_requiring_backward) extra_outputs[idx] = ys # Flatten list of extra outputs @@ -174,7 +168,7 @@ def forward( extra_outputs_flat.extend(ys) # 
Save context for backward pass - if is_grad_enabled: + if func_ctx is not None: # Flatten list of saved tensors to_save = [] @@ -187,24 +181,29 @@ def forward( ctx._saved_tensors_range = (range_start, range_end) # Save tensors for backward - if with_quantized_compute: - tensors_to_save, tensor_objects = prepare_for_saving(*to_save) - func_ctx.save_for_backward(*tensors_to_save) - func_ctx.tensor_objects = tensor_objects - else: - func_ctx.save_for_backward(*to_save) + tensors_to_save, tensor_objects = prepare_for_saving(*to_save) + func_ctx.save_for_backward(*tensors_to_save) + func_ctx.tensor_objects = tensor_objects + + # Whether to perform recipe update in backward pass + is_first_module = False + if fuser.first_op_requiring_backward < fuser._num_basic_ops: + is_first_module = FP8GlobalStateManager.is_first_fp8_module() # Other context func_ctx.backward_ops = fuser._backward_ops func_ctx.basic_ops = fuser._basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.basic_op_num_params = fuser._num_list_basic_op_params - func_ctx.num_extra_inputs = fuser._num_extra_inputs + func_ctx.basic_op_num_params = fuser._basic_op_num_params + func_ctx.num_extra_inputs = fuser.num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) - func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() - func_ctx.with_quantized_compute = with_quantized_compute + func_ctx.is_first_module = is_first_module - x.requires_grad_(requires_grad) + # Mark output tensors as not deletable in backward + for tensor in [x] + extra_outputs_flat: + tensor._do_not_clear = True + + x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) if extra_outputs_flat: return x, *extra_outputs_flat @@ -226,10 +225,7 @@ def backward( basic_op_ctxs = func_ctx.basic_op_ctxs # Restore saved tensors - if func_ctx.with_quantized_compute: - saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors) - else: - saved_tensors = func_ctx.saved_tensors + saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors) # Unflatten list of saved tensors for ctx in basic_op_ctxs: @@ -310,7 +306,6 @@ def backward( dx, # input_ None, # fuser None, # basic_op_kwargs - None, # is_grad_enabled *grad_params_flat, *grad_extra_inputs_flat, ) @@ -323,19 +318,12 @@ class OperationFuser: ---------- ops: list of FusibleOperation Pipeline of operations - fuse_ops: bool - Whether to attempt fusing operations - recipe: Recipe, optional - Quantization recipe to use when fusing and executing operations. - Note: certain fusions may depend on what kind of recipe is being used. 
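Editor's note: since the fuser no longer takes fuse_ops/recipe at construction, the fusion plan is now derived lazily and cached on the state it depends on, then rebuilt only when that state changes. A generic standalone sketch of this caching pattern (names are illustrative, not the TE API):

class LazyFusionCache:
    """Re-fuse an op list only when the recipe type or grad requirements change."""

    def __init__(self, ops, fuse_fn):
        self.ops = list(ops)
        self.fuse_fn = fuse_fn            # e.g. a chain of pattern-matching passes
        self._cache_key = None
        self._fused_ops = None

    def maybe_fuse(self, recipe, first_op_requiring_backward):
        key = (type(recipe), first_op_requiring_backward)
        if key != self._cache_key:
            self._fused_ops = self.fuse_fn(self.ops, recipe)
            self._cache_key = key
        return self._fused_ops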
""" def __init__( self, ops: list[FusibleOperation], - fuse_ops: bool, - recipe: Optional[Recipe], ) -> None: # Get list of basic operations @@ -349,25 +337,22 @@ def __init__( self._basic_ops: list[BasicOperation] = basic_ops # Number of extra tensor inputs - self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops) + self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops) + self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs) - # Ops for forward and backward pass + # Ops for forward and backward pass, will be populated in fuse_ops self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]] - self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)] - self._backward_ops = list(reversed(self._forward_ops)) - # Flag for checking if this is the first iteration - self._is_first_forward = True - - # Fuse ops if needed - self.recipe = recipe - if fuse_ops: - self.fuse_ops() + # Cache and detect change of state relevant for fusing operations + self.recipe_type = None + self.first_op_requiring_backward = 0 + self._last_amax_history_len = 0 # Flatten list of parameters - self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()] - self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops] + self._basic_op_params = [list(op.parameters()) for op in self._basic_ops] + self._basic_op_num_params = list(map(len, self._basic_op_params)) + self._flat_basic_op_params = sum(self._basic_op_params, []) @classmethod def _fuse_forward_ops( @@ -380,6 +365,7 @@ def _fuse_forward_ops( ops = fuse_userbuffers_forward_linear(ops) ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) + ops = fuse_forward_linear_scale_add(ops) return ops @classmethod @@ -392,13 +378,75 @@ def _fuse_backward_ops( if not IS_HIP_EXTENSION: ops = fuse_userbuffers_backward_linear(ops) ops = fuse_backward_linear_add(ops) - ops = fuse_backward_bias_activation(ops, recipe) + ops = fuse_backward_linear_scale(ops) + ops = fuse_backward_activation_bias(ops, recipe) + ops = fuse_backward_add_rmsnorm(ops) return ops - def fuse_ops(self) -> None: - """Attempt to fuse operations""" - self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe) - self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe) + def maybe_fuse_ops( + self, + is_grad_enabled: bool, + recipe: Optional[Recipe], + input_: torch.Tensor, + extra_inputs: list[Iterable[torch.Tensor]], + ): + """Attempt to fuse operations if neccesary""" + + # Determine which basic ops require backward + if not is_grad_enabled: + first_op_requiring_backward = self._num_basic_ops + elif input_.requires_grad: + first_op_requiring_backward = 0 + else: + first_op_requiring_backward = self._num_basic_ops + for op_idx in range(self._num_basic_ops): + op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx]) + if any(tensor.requires_grad for tensor in op_inputs): + first_op_requiring_backward = op_idx + break + + # Early exit if fusion parameters haven't changed + need_reset = False + recipe_type = type(recipe) + fusion_params = (recipe_type, first_op_requiring_backward) + if fusion_params != (self.recipe_type, self.first_op_requiring_backward): + # Recipe type or grad requirmenets have changed + need_reset = True + elif ( + recipe is not None + and recipe.delayed() + and self._last_amax_history_len != 
recipe.amax_history_len + ): + # FP8 delayed scaling has changed amax history length + need_reset = True + if not need_reset: + return + + # Reset recipe state + for op in self._basic_ops: + op.reset_recipe_state(recipe=recipe) + + # Check if this is the first iteration + if self.recipe_type is None: + for op in self._basic_ops: + op.pre_first_fuser_forward() + + # Prepare basic op lists for fusions + forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)] + backward_ops = list(reversed(forward_ops[first_op_requiring_backward:])) + + # Fuse ops + self._forward_ops = self._fuse_forward_ops(forward_ops, recipe) + self._backward_ops = self._fuse_backward_ops(backward_ops, recipe) + + # Save current fusion params + self.recipe_type, self.first_op_requiring_backward = fusion_params + + # Save amax history length + if isinstance(recipe, DelayedScaling): + self._last_amax_history_len = recipe.amax_history_len + else: + self._last_amax_history_len = 0 def __call__( self, @@ -407,23 +455,32 @@ def __call__( basic_op_kwargs: Optional[list[dict[str, Any]]] = None, ) -> torch.Tensor | tuple[torch.Tensor, ...]: # Verify extra input count - if len(extra_inputs) != self._num_extra_inputs: + if len(extra_inputs) != self.num_extra_inputs: raise ValueError( - f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}" + f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}" ) - # Initialization before forward pass - if self._is_first_forward: - for op in self._basic_ops: - op.pre_first_forward(recipe=self.recipe) - self._is_first_forward = False - # Canonicalize op kwargs if basic_op_kwargs is None: basic_op_kwargs = [{}] * self._num_basic_ops - # Fuser forward pass + # Unflatten list of extra tensor inputs + extra_inputs_copy = list(extra_inputs) + basic_op_extra_inputs = [] + for op in self._basic_ops: + xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs) + basic_op_extra_inputs.append(xs) + + # Get environment state + recipe = None + if FP8GlobalStateManager.is_fp8_enabled(): + recipe = FP8GlobalStateManager.get_fp8_recipe() is_grad_enabled = torch.is_grad_enabled() + + # Attempt to fuse operations if neccesary + self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs) + + # Fuser forward pass if is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply args = [] @@ -434,8 +491,7 @@ def __call__( input, self, basic_op_kwargs, - is_grad_enabled, - *self._basic_op_params, + *self._flat_basic_op_params, *extra_inputs, ) return forward_func(*args) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 8ed2702a7..325126a3d 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Optional +from typing import Any, Optional import torch @@ -54,7 +54,8 @@ class Linear(FusedOperation): weight's `main_grad` attribute instead of relying on PyTorch autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be - meaningful. + meaningful. This is primarily intented to integrate with + Megatron-LM. 
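Editor's note: a short usage-side sketch of what the accumulate_into_main_grad contract in the docstring above means for the caller (attribute names follow the docstring; the surrounding training-loop details are assumed):

import torch

weight = torch.nn.Parameter(torch.empty(1024, 1024))
# The caller owns the gradient buffer and must attach it before backward;
# with accumulate_into_main_grad=True, `weight.grad` is not guaranteed to be meaningful.
weight.main_grad = torch.zeros(1024, 1024, dtype=torch.float32)
# ... build the op with accumulate_into_main_grad=True, run forward/backward,
# then read the accumulated gradient from weight.main_grad ...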
""" @@ -91,6 +92,8 @@ def __init__( # Construct basic ops ops = [] + linear_idx = None + bias_idx = None linear_kwargs = { "in_features": in_features, "out_features": out_features, @@ -111,14 +114,16 @@ def __init__( } if tensor_parallel_mode == "row": # Row TP: GEMM + bias + reduction + linear_idx = len(ops) linear_kwargs["in_features"] = local_in_features linear_kwargs["out_features"] = local_out_features linear_kwargs["tensor_parallel_mode"] = None linear_kwargs["tensor_parallel_group"] = None linear_kwargs["sequence_parallel"] = False - bias_kwargs["size"] *= tensor_parallel_size ops.append(BasicLinear(**linear_kwargs)) if bias: + bias_idx = len(ops) + bias_kwargs["size"] *= tensor_parallel_size ops.append(Bias(**bias_kwargs)) if sequence_parallel: ops.append(ReduceScatter(tensor_parallel_group)) @@ -126,45 +131,81 @@ def __init__( ops.append(AllReduce(tensor_parallel_group)) else: # Column TP or no TP: (gather + GEMM) + bias + linear_idx = len(ops) ops.append(BasicLinear(**linear_kwargs)) if bias: + bias_idx = len(ops) ops.append(Bias(**bias_kwargs)) # Initialize base class super().__init__(ops) - self._has_bias: bool = bias + # Register parameters + self._linear_idx: Optional[int] = linear_idx + self._bias_idx: Optional[int] = bias_idx + self.register_parameter("weight", self.basic_ops[self._linear_idx].weight) + bias = None + if self._bias_idx is not None: + bias = self.basic_ops[self._bias_idx].bias + self.register_parameter("bias", bias) - @property - def weight(self) -> torch.nn.Parameter: - """Weight tensor + def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: + """Add a parameter to the module - Parameter is owned by `BasicLinear` operation. + Also updates the basic operation that owns the parameter. """ - return self.basic_ops[0].weight - - @weight.setter - def weight(self, value: Optional[torch.nn.Parameter]) -> None: - self.basic_ops[0].weight = value - - @property - def bias(self) -> Optional[torch.nn.Parameter]: - """Bias tensor - - Parameter is owned by `Bias` operation. - - """ - if self._has_bias: - return self.basic_ops[1].bias - return None - - @bias.setter - def bias(self, value: Optional[torch.nn.Parameter]) -> None: - if self._has_bias: - self.basic_ops[1].bias = value - elif value is not None: + if name == "bias" and self._bias_idx is None and param is not None: raise ValueError( "Attempted to set bias parameter in Linear operation " "that does not have bias enabled" ) + super().register_parameter(name, param) + if name == "weight": + self.basic_ops[self._linear_idx].weight = param + elif name == "bias" and self._bias_idx is not None: + self.basic_ops[self._bias_idx].bias = param + + def state_dict(self, *, prefix: str = "", **kwargs) -> dict[str, Any]: + """Save state""" + state_dict = super().state_dict(prefix=prefix, **kwargs) + + # Remove basic op params from state dict + # Note: Logically, basic ops own params and fused ops are + # considered as stateless. However, we register weight and + # bias params in the linear op for convenience. We remove + # these redudant params from the checkpoint for backward + # compatibility. 
+ if f"{prefix}weight" in state_dict: + del state_dict[f"{prefix}weight"] + if f"{prefix}bias" in state_dict: + del state_dict[f"{prefix}bias"] + + return state_dict + + def _load_from_state_dict( + self, + state_dict: dict[str, Any], + prefix: str, + *args, + **kwargs, + ) -> None: + + # Add basic op params to state dict + # Note: Logically, basic ops own params and fused ops are + # considered as stateless. However, we register weight and + # bias params in the linear op for convenience. We remove + # these redudant params from the checkpoint for backward + # compatibility. + if f"{prefix}weight" not in state_dict: + state_dict[f"{prefix}weight"] = state_dict[ + f"{prefix}basic_ops.{self._linear_idx}.weight" + ] + if f"{prefix}bias" not in state_dict: + if self._bias_idx is None: + state_dict[f"{prefix}bias"] = None + else: + state_dict[f"{prefix}bias"] = state_dict[f"{prefix}basic_ops.{self._bias_idx}.bias"] + + # Load state dict + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 8490019e5..903bc49d5 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -15,9 +15,6 @@ from transformer_engine.common.recipe import Recipe from ..fp8 import ( - MXFP8BlockScalingRecipeState, - DelayedScalingRecipeState, - Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, fp8_autocast, @@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): def is_fused_op(self) -> bool: """Whether this op is the fusion of one or more basic ops""" - def pre_first_forward( - self, - *, - recipe: Optional[Recipe], - ) -> None: - """Preprocessing before forward pass""" + def pre_first_fuser_forward(self) -> None: + """Preprocessing before first fuser forward pass""" def get_input_quantizer(self) -> Optional[Quantizer]: """Get builder class for quantized input tensor""" - def get_grad_input_quantizer(self) -> Optional[Quantizer]: - """Get builder class for quantized input's grad tensor""" + def get_grad_output_quantizer(self) -> Optional[Quantizer]: + """Get builder class for quantized output's grad tensor""" def fuser_forward( self, @@ -84,7 +77,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -104,8 +97,8 @@ def fuser_forward( Input tensor basic_op_extra_inputs: list of torch.Tensor Extra tensor inputs to basic operations - prev_op_grad_input_quantizer: Quantizer, optional - The grad_input_quantizer of the preceeding operation + prev_op_grad_output_quantizer: Quantizer, optional + The grad_output_quantizer of the preceeding operation next_op_input_quantizer: Quantizer, optional The input_quantizer of the following operation basic_op_kwargs: list of dict @@ -186,8 +179,11 @@ def __init__(self) -> None: super().__init__() # Objects for quantization - self._quantizers: Optional[dict[str, list[Quantizer]]] = None self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None + self._quantizers: Optional[dict[str, list[Quantizer]]] = None + with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None + self.reset_recipe_state(recipe=recipe) 
@property def is_fused_op(self) -> bool: @@ -214,120 +210,141 @@ def get_input_quantizer(self) -> Optional[Quantizer]: return self.get_quantizer("forward", 0) return None - def get_grad_input_quantizer(self) -> Optional[Quantizer]: + def get_grad_output_quantizer(self) -> Optional[Quantizer]: if self.num_quantizers("backward") > 0: return self.get_quantizer("backward", 0) return None - def _reset_quantization_recipe_state( + def reset_recipe_state( self, *, - recipe: Recipe, + recipe: Optional[Recipe], ) -> None: """Construct state for quantization recipe""" - # Quantization recipe state for forward and backward pass - self._fp8_metas = {"forward": None, "backward": None} - self._quantizers = {"forward": [], "backward": []} - for mode in ("forward", "backward"): - num_quantizers = self.num_quantizers(mode) - if num_quantizers == 0: - continue - - if recipe.float8_block_scaling(): - raise NotImplementedError( - "Fusible operations do not support FP8 block scaling recipe" - ) - - # Construct quantization recipe state - recipe_state = RecipeState.create( - recipe, - mode=mode, - num_quantizers=num_quantizers, - ) - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=(mode == "forward"), - ) - self._fp8_metas[mode] = { - fp8_meta_key: recipe_state, - "recipe": recipe, - "fp8_group": FP8GlobalStateManager.get_fp8_group(), - } - - # Construct builder class for quantized tensors - self._quantizers[mode] = recipe_state.make_quantizers() + # Clear quantization state if necessary + if recipe is None: + self._fp8_metas = None + self._quantizers = None + return - def _update_quantization_recipe_state( - self, - *, - recipe: Recipe, - ) -> None: - """Make sure quantizer state matches quantization recipe""" + # Communication group for FP8 amax reductions + fp8_group = FP8GlobalStateManager.get_fp8_group() - # Reset quantization state if needed + # Skip resetting recipe type if it did not actually change. 
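Editor's note: further down in this hunk the delayed-scaling amax history is reallocated when amax_history_len changes; the resize rule is simply truncate-or-zero-pad along the history dimension. A minimal standalone sketch of that rule (2-D history tensor of shape (amax_history_len, num_quantized_tensors) assumed):

import torch

@torch.no_grad()
def resize_amax_history(history: torch.Tensor, target_len: int) -> torch.Tensor:
    current_len = history.size(0)
    if target_len < current_len:
        return history[:target_len].clone()          # shrink: keep the most recent rows
    if target_len > current_len:
        # Grow: append zero-initialized history rows at the end.
        return torch.nn.functional.pad(history, pad=(0, 0, 0, target_len - current_len))
    return history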
+ # This could happen for example if calling BasicOperation.forward directly, as in that + # case, the OperationFuser is not persistent, or when loading from a checkpoint + need_to_reset_recipe_state = False if self._fp8_metas is None or self._quantizers is None: - self._reset_quantization_recipe_state(recipe=recipe) - return - for mode in ("forward", "backward"): - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=(mode == "forward"), - ) - if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: - continue - recipe_state = self._fp8_metas[mode][fp8_meta_key] - need_to_reset_recipe_state = ( - (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) - or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) - or ( - recipe.float8_block_scaling() - and not isinstance(recipe_state, Float8BlockScalingRecipeState) + need_to_reset_recipe_state = True + else: + for mode in ("forward", "backward"): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), ) - ) - if need_to_reset_recipe_state: - self._reset_quantization_recipe_state(recipe=recipe) - return + if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: + continue + recipe_state = self._fp8_metas[mode][fp8_meta_key] + if not isinstance(recipe, type(recipe_state.recipe)): + need_to_reset_recipe_state = True + break + + if need_to_reset_recipe_state: + # Construct quantization recipe states + self._fp8_metas = {"forward": None, "backward": None} + self._quantizers = {"forward": [], "backward": []} + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: + continue - # Quantization recipe state for forward and backward pass - for mode in ("forward", "backward"): - num_quantizers = self.num_quantizers(mode) - if num_quantizers == 0: - continue + if recipe.float8_block_scaling(): + raise NotImplementedError( + "Fusible operations do not support FP8 block scaling recipe" + ) - # Update FP8 metadata - fp8_meta = self._fp8_metas[mode] - fp8_meta["recipe"] = recipe - fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + # Construct quantization recipe state + recipe_state = RecipeState.create( + recipe, + mode=mode, + num_quantizers=num_quantizers, + ) + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + self._fp8_metas[mode] = { + fp8_meta_key: recipe_state, + "recipe": recipe, + "fp8_group": fp8_group, + } - # Get recipe state - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=(mode == "forward"), - ) - recipe_state = fp8_meta[fp8_meta_key] + # Construct builder class for quantized tensors + self._quantizers[mode] = recipe_state.make_quantizers() + else: + # Update quantization recipe states + for mode in ("forward", "backward"): + if self._fp8_metas[mode] is None: + continue + self._fp8_metas[mode]["recipe"] = recipe + self._fp8_metas[mode]["fp8_group"] = fp8_group - # Reallocate amax history if needed - if not recipe.delayed(): - continue + # Update amax history for FP8 delayed scaling + if recipe.delayed(): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + recipe_state = self._fp8_metas[mode][fp8_meta_key] - current_length = recipe_state.amax_history.size(0) - target_length = recipe.amax_history_len - if current_length != target_length: - with torch.no_grad(): + # Reallocate amax history if needed + current_length = 
recipe_state.amax_history.size(0) + target_length = recipe.amax_history_len if target_length < current_length: - recipe_state.amax_history = recipe_state.amax_history[ - :target_length - ].clone() - else: - recipe_state.amax_history = torch.nn.functional.pad( - recipe_state.amax_history, - pad=(0, 0, 0, target_length - current_length), - ) - self._quantizers[mode] = recipe_state.make_quantizers() + with torch.no_grad(): + recipe_state.amax_history = recipe_state.amax_history[ + :target_length + ].clone() + elif target_length > current_length: + with torch.no_grad(): + recipe_state.amax_history = torch.nn.functional.pad( + recipe_state.amax_history, + pad=(0, 0, 0, target_length - current_length), + ) + + # Update quantizers with new amax pointers + self._quantizers[mode] = recipe_state.make_quantizers() + + # Update the global buffers with new amax pointers + if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]: + pos, buffer_key = self._fp8_metas[mode][ + FP8GlobalStateManager.get_buffer_info() + ] + if buffer_key in FP8GlobalStateManager.global_amax_buffer: + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( + recipe_state.amax_history[0] + ) + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][ + pos + ] = recipe_state.amax_history + + # Add meta tensors to global buffer to participate in reduction + for mode in ("forward", "backward"): + if ( + FP8GlobalStateManager.is_fp8_enabled() + and self.num_quantizers(mode) + and not FP8GlobalStateManager.fp8_graph_capturing() + ): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self._fp8_metas[mode], + ) def get_quantizer( self, mode: str, index: int, - ) -> Quantizer: + ) -> Optional[Quantizer]: """Get builder class for quantized tensor Parameters @@ -337,7 +354,7 @@ def get_quantizer( """ if self._quantizers is None: - self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) + return None return self._quantizers[mode][index] @torch.no_grad() @@ -388,33 +405,13 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) - def pre_first_forward( - self, - *, - recipe: Optional[Recipe], - ) -> None: - """Preprocessing before forward pass""" - - # Initialize FP8 metadata if needed - if recipe is not None: - self._update_quantization_recipe_state(recipe=recipe) - if not FP8GlobalStateManager.fp8_graph_capturing(): - if self.num_quantizers("forward"): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self._fp8_metas["forward"], - ) - if self.num_quantizers("backward"): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self._fp8_metas["backward"], - ) - @abc.abstractmethod def op_forward( self, ctx: OperationContext, input_: torch.Tensor, *, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], **kwargs: Any, ) -> torch.Tensor: @@ -426,8 +423,8 @@ def op_forward( Context to coordinate between forward and backward passes input_: torch.Tensor Input tensor - prev_op_grad_input_quantizer: Quantizer, optional - The grad_input_quantizer of the preceeding operation + prev_op_grad_output_quantizer: Quantizer, optional + The grad_output_quantizer of the preceeding operation next_op_input_quantizer: 
Quantizer, optional The input_quantizer of the following operation @@ -468,7 +465,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, list[tuple[()]]]: @@ -482,7 +479,7 @@ def fuser_forward( output = self.op_forward( basic_op_ctxs[0], input_, - prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, + prev_op_grad_output_quantizer=prev_op_grad_output_quantizer, next_op_input_quantizer=next_op_input_quantizer, **basic_op_kwargs[0], ) @@ -518,9 +515,7 @@ def forward( """Apply operation""" from .fuser import OperationFuser - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None - return OperationFuser([self], fuse_ops=False, recipe=recipe)( + return OperationFuser([self])( input, *extra_inputs, basic_op_kwargs=[kwargs], @@ -630,7 +625,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: # Get op's quantizer state, initializing if needed if self._fp8_metas is None or self._fp8_metas[mode] is None: with fp8_autocast(fp8_recipe=state[mode]["recipe"]): - self._reset_quantization_recipe_state(recipe=state[mode]["recipe"]) + self.reset_recipe_state(recipe=state[mode]["recipe"]) fp8_meta = self._fp8_metas[mode] # Load extra items @@ -708,13 +703,12 @@ def is_fused_op(self) -> bool: def get_input_quantizer(self) -> Optional[Quantizer]: return self.basic_ops[0].get_input_quantizer() - def get_grad_input_quantizer(self) -> Optional[Quantizer]: - return self.basic_ops[-1].get_grad_input_quantizer() + def get_grad_output_quantizer(self) -> Optional[Quantizer]: + return self.basic_ops[-1].get_grad_output_quantizer() - def pre_first_forward(self, *args, **kwargs) -> None: - """Preprocessing before forward pass""" + def pre_first_fuser_forward(self) -> None: for op in self.basic_ops: - op.pre_first_forward(*args, **kwargs) + op.pre_first_fuser_forward() def forward( self, @@ -727,9 +721,7 @@ def forward( basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] from .fuser import OperationFuser - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None - return OperationFuser([self], fuse_ops=False, recipe=recipe)( + return OperationFuser([self])( input, *extra_inputs, basic_op_kwargs=basic_op_kwargs, diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index f18678309..2afda58e4 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -10,7 +10,6 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.fuser import OperationFuser @@ -147,7 +146,6 @@ def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential: def _make_module_groups( cls, modules: Iterable[torch.nn.Module], - recipe: Optional[Recipe], ) -> list[OperationFuser | torch.nn.Module]: """Make list of modules, with fusible operations grouped together""" @@ -162,24 +160,7 @@ def _make_module_groups( groups.append(module) for idx, group in enumerate(groups): if isinstance(group, list): - groups[idx] = OperationFuser(group, fuse_ops=True, 
recipe=recipe) - - # Check if operations expect extra input or output tensors - # Note: If any op has extra inputs or outputs, then the entire - # Sequential must be made up of TE ops. - if len(groups) > 1: - ops = [] - for group in groups: - if isinstance(group, OperationFuser): - ops.extend(group._basic_ops) - num_extra_inputs = sum(op.num_extra_inputs for op in ops) - num_extra_outputs = sum(op.num_extra_outputs for op in ops) - if num_extra_inputs > 0 or num_extra_outputs > 0: - raise RuntimeError( - f"`Sequential` expects {num_extra_inputs} extra inputs " - f"and {num_extra_outputs} extra outputs, " - "but it contains non-fusible operations" - ) + groups[idx] = OperationFuser(group) return groups @@ -190,22 +171,28 @@ def forward( ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass""" - # Get current global state - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None - global_state = (with_quantized_compute, type(recipe)) - - # Reset module groups is global state changed - if self._last_global_state != global_state: - self._module_groups = None - self._last_global_state = global_state - # Create module groups if needed if self._module_groups is None: - self._module_groups = self._make_module_groups(self._modules.values(), recipe) + self._module_groups = self._make_module_groups(self._modules.values()) # Forward pass for each module group x = input + extra_outputs: list[torch.Tensor] = [] for module_group in self._module_groups: - x = module_group(x, *extra_inputs) + if isinstance(module_group, OperationFuser): + xs, extra_inputs = ( + (x,) + extra_inputs[: module_group.num_extra_inputs], + extra_inputs[module_group.num_extra_inputs :], + ) + xs = module_group(*xs) + if isinstance(xs, tuple): + x, ys = xs[0], xs[1:] + extra_outputs.extend(ys) + else: + x = xs + else: + x = module_group(x) + + if extra_outputs: + return (x,) + tuple(extra_outputs) return x diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 86b79eaad..e86873b12 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -12,14 +12,30 @@ import os import shutil from pathlib import Path - +import platform +import urllib import setuptools +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from packaging.version import parse try: + import torch from torch.utils.cpp_extension import BuildExtension except ImportError as e: raise RuntimeError("This package needs Torch to build.") from e +FORCE_BUILD = os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE" +FORCE_CXX11_ABI = os.getenv("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +PACKAGE_NAME = "transformer_engine_torch" +BASE_WHEEL_URL = ( + "https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}" +) +# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as +# torch._C._GLIBCXX_USE_CXX11_ABI +# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 +if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True current_file_path = Path(__file__).parent.resolve() build_tools_dir = current_file_path.parent.parent / "build_tools" @@ -34,13 +50,94 @@ from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy ) from build_tools.te_version import 
te_version -from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements +from build_tools.pytorch import ( + setup_pytorch_extension, + install_requirements, + test_requirements, +) os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return f"linux_{platform.uname().machine}" + if sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + if sys.platform == "win32": + return "win_amd64" + + raise ValueError(f"Unsupported platform: {sys.platform}") + + +def get_wheel_url(): + """Construct the wheel URL for the current platform.""" + torch_version_raw = parse(torch.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + nvte_version = te_version() + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 + # to save CI time. Minor versions should be compatible. + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}" + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{nvte_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + + wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{nvte_version}", wheel_name=wheel_filename) + + return wheel_url, wheel_filename + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all grouped gemm installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + + def run(self): + if FORCE_BUILD: + super().run() + + wheel_url, wheel_filename = get_wheel_url() + print("Guessing wheel URL: ", wheel_url) + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except (urllib.error.HTTPError, urllib.error.URLError): + print("Precompiled wheel not found. 
Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + + if __name__ == "__main__": # Extensions common_headers_dir = "common_headers" @@ -55,11 +152,11 @@ # Configure package setuptools.setup( - name="transformer_engine_torch", + name=PACKAGE_NAME, version=te_version(), description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, install_requires=install_requirements(), tests_require=test_requirements(), ) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 882650ffb..da0220eb7 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -59,7 +59,7 @@ def __new__( instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv @@ -124,9 +124,15 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def get_data_tensors(self): + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" - return self._rowwise_data, self._columnwise_data + if rowwise_data and columnwise_data: + return self._rowwise_data, self._columnwise_data + if rowwise_data: + return self._rowwise_data + if columnwise_data: + return self._columnwise_data + raise ValueError("No data to get, both rowwise_data and columnwise_data are False") def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: """Takes dequantized columnwise data and permutes to a rowwise shape""" @@ -343,9 +349,14 @@ def _create_columnwise(self): def _transpose_columnwise_data(self): """Plainly transpose the columnwise data and scale inv.""" if self._columnwise_data is not None: + # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. + _old_data = self._columnwise_data self._columnwise_data = tex.fp8_transpose( self._columnwise_data, self._fp8_dtype, out=None ) + _old_data.data = _empty_tensor() + del _old_data def __repr__(self): if self._rowwise_data is not None: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index c0dc6e651..6d4822344 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -86,7 +86,7 @@ def __new__( else: instance = super().__new__(cls, *args, **kwargs) instance._data = data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._scale_inv = fp8_scale_inv instance._transpose = data_transpose @@ -95,8 +95,13 @@ def __new__( return instance def clear(self): - """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" - for t in (self._data, self._transpose, self._scale_inv): + """Deallocate this tensor's memory. Typically not needed and must be used carefully. + + Scale-inv tensor is not deallocated because it's often shared + between multiple FP8 tensors. + + """ + for t in (self._data, self._transpose): if t is not None: t.data = _empty_tensor() self._transpose_invalid = True @@ -128,9 +133,15 @@ def restore_from_saved( self._scale_inv = tensors[2] return tensors[3:] - def get_data_tensors(self): + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" - return self._data, self._transpose + if rowwise_data and columnwise_data: + return self._data, self._transpose + if rowwise_data: + return self._data + if columnwise_data: + return self._transpose + raise ValueError("No data to get, both rowwise_data and columnwise_data are False") def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Dequantize to a higher precision.""" diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index 153732000..11055c4cc 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -91,7 +91,7 @@ def __new__( instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv @@ -144,9 +144,15 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def get_data_tensors(self): + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" - return self._rowwise_data, self._columnwise_data + if rowwise_data and columnwise_data: + return self._rowwise_data, self._columnwise_data + if rowwise_data: + return self._rowwise_data + if columnwise_data: + return self._columnwise_data + raise ValueError("No data to get, both rowwise_data and columnwise_data are False") def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Dequantize to a higher precision.""" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index bac715949..0e41fc9c5 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -521,7 +521,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._rowwise_data = src._rowwise_data dst._columnwise_data = src._columnwise_data - dst._quantizer = src._quantizer + dst._quantizer = src._quantizer.copy() dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 742096dca..a0a17d1a1 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -119,10 +119,9 @@ def make_empty( # Allocate FP8 data 
transpose if needed data_transpose = None if self.columnwise_usage: - inner_dim = data.size(-1) + transpose_shape = [data.size(-1)] + list(data.shape[:-1]) data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, ) @@ -189,13 +188,19 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: """Function using primitives with ONNX defined translations.""" - out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item()) + out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv) out = out.to(tensor.dtype) return out def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return DelayedScaling + def supports_only_rowwise_all_gather(self) -> bool: + """ + Float8Quantizer supports only rowwise all-gather + """ + return True + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -241,7 +246,7 @@ def __init__( amax_epsilon: float = 0.0, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.scale = torch.ones(1, dtype=torch.float32, device=device) + self.scale = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype self.with_amax_reduction = with_amax_reduction @@ -361,15 +366,25 @@ def create_tensor_from_data( def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Function using primitives with ONNX defined translations.""" - raise NotImplementedError( - "Float8CurrentScalingQuantizer does not support ONNX quantization yet." + if tensor.dtype != torch.float32: + tensor = tensor.to(torch.float32) + data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor) + return Float8Tensor( + shape=data.shape, + dtype=torch.float32, + data=data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.dtype, + requires_grad=False, + data_transpose=None, + quantizer=self, ) def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: """Function using primitives with ONNX defined translations.""" - raise NotImplementedError( - "Float8CurrentScalingQuantizer does not support ONNX dequantization yet." 
- ) + out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv) + out = out.to(tensor.dtype) + return out def _canonicalized_amax_reduction_group(self) -> dist_group_type: """Get process group for amax reduction""" @@ -378,6 +393,12 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8CurrentScaling + def supports_only_rowwise_all_gather(self) -> bool: + """ + Float8CurrentScalingQuantizer supports only rowwise all-gather + """ + return True + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data @@ -706,7 +727,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Float8Tensor attributes self._data = tensor._data - self._quantizer = tensor._quantizer + self._quantizer = tensor._quantizer.copy() self._fp8_dtype = tensor._fp8_dtype self._scale_inv = tensor._scale_inv self._transpose = tensor._transpose diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index b3504b175..16b1568cb 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -110,6 +112,7 @@ def make_empty( # Allocate FP8 data data = torch.empty(shape, dtype=torch.uint8, device=device) + # ROCm TE does not implement fused zero-padding, so use a zero-initialized tensor here scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), @@ -122,6 +125,7 @@ def make_empty( columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty_like(data) + # ROCm TE does not implement fused zero-padding, so use a zero-initialized tensor here columnwise_scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), @@ -443,7 +447,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data - self._quantizer = tensor._quantizer + self._quantizer = tensor._quantizer.copy() self._fp8_dtype = tensor._fp8_dtype self._rowwise_scale_inv = tensor._rowwise_scale_inv self._columnwise_scale_inv = tensor._columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 3a6eb7290..2f634f399 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -264,6 +264,10 @@ def onnx_dequantize(self, tensor) -> torch.Tensor: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: """Returns recipe class that is compatible with this quantizer""" + def supports_only_rowwise_all_gather(self) -> bool: + """Returns True if the quantizer supports only rowwise all-gather""" + return False + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype""" diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index b9d59f496..89e43f845 100644 --- a/transformer_engine/pytorch/transformer.py +++
b/transformer_engine/pytorch/transformer.py @@ -175,7 +175,8 @@ class TransformerLayer(torch.nn.Module): if set to `False`, the transformer layer will not learn any additive biases. activation : str, default = 'gelu' Type of activation used in MLP block. - Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'. + Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', + 'silu', and 'swiglu'. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the @@ -236,14 +237,23 @@ class TransformerLayer(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatenations/splits and also enables the argument `fuse_wgrad_accumulation`. - use_qk_norm: bool, default = 'False' - if set to `True`, L2 normalization is applied to query and key tensors - after RoPE (if applicable) but before attention computation. - This follows the Llama4 approach for QK normalization to improve - training stability and model performance. + qk_norm_type: Optional[str], default = None + type of normalization to apply to query and key tensors. + Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied; + otherwise the selected normalization is applied to the query and key tensors. + Normalization is applied after RoPE (if applicable) but before attention computation + when `qk_norm_before_rope` is False. This follows the approach of e.g. Llama4, which uses + QK normalization to improve training stability and model performance. qk_norm_eps: float, default = 1e-6 - epsilon value for L2 normalization of query and key tensors. - Only used when `use_qk_norm` is True. + epsilon value for normalization of query and key tensors. + Only used when `qk_norm_type` is not None. + qk_norm_before_rope: bool, default = `False` + if set to `True`, query and key normalization is applied before rotary position + embedding. When `False` (default), normalization is applied after RoPE. + This parameter allows supporting different architectural variants that apply + QK normalization at different points.
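As a usage sketch for the new qk_norm_* arguments documented above (not part of the patch; the layer sizes and input shape are illustrative and a CUDA device is assumed):

import torch
import transformer_engine.pytorch as te

# RMSNorm applied to query/key before RoPE, mirroring the qk_norm_* options above.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    qk_norm_type="RMSNorm",
    qk_norm_eps=1e-6,
    qk_norm_before_rope=True,
)
# Default attn_input_format is "sbhd": [sequence, batch, hidden].
x = torch.randn(128, 2, 1024, device="cuda")
y = layer(x)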
""" def __init__( @@ -293,8 +303,9 @@ def __init__( device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", name: str = None, - use_qk_norm: bool = False, + qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, + qk_norm_before_rope: bool = False, ) -> None: super().__init__() @@ -397,8 +408,9 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, - use_qk_norm=use_qk_norm, + qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, + qk_norm_before_rope=qk_norm_before_rope, name=name + ".self_attention" if name is not None else None, ) @@ -413,8 +425,9 @@ def __init__( return_bias=True, normalization=normalization, device=device, - use_qk_norm=use_qk_norm, + qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, + qk_norm_before_rope=qk_norm_before_rope, name=name + ".inter_attention" if name is not None else None, ) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 20e0b737d..6eda84b91 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -236,6 +236,7 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, + grad_output_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): @@ -258,6 +259,7 @@ def element_mul_kernel( X_ptr += program_id * X_stride # Load the gradient output value + grad_output_ptr += program_id * grad_output_stride grad_output = tl.load(grad_output_ptr) # Perform the element-wise multiplication @@ -346,13 +348,17 @@ def cross_entropy_forward( return loss, _input -def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor): +def cross_entropy_backward( + _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False +): """Backward implementation of cross entropy loss kernel""" # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time - if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + # Only check torch.equal when not in CUDA graph capturable mode + if not is_cg_capturable and torch.equal( + grad_output, torch.tensor(1.0, device=grad_output.device) + ): pass - else: B, SQ, V = _input.shape n_rows = B * SQ @@ -362,6 +368,7 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor): _input, _input.stride(-2), grad_output, + 1 if grad_output.numel() > 1 else 0, V, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS, diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 9ce01362f..ceb88108f 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -359,7 +359,7 @@ def _permute_kernel( if prob == 0.0: # for routing_map padding # dst_row != -1 and prob == 0.0 means that this slot is padded - tl.store(output_ptr + output_off, 0, mask=mask) + tl.store(output_ptr + output_off, 0.0, mask=mask) else: tl.store(output_ptr + output_off, inp, mask=mask) else: diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index 4c7033132..b6a7270a3 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information """Python interface for cast extensions""" @@ -115,6 +115,7 @@ def te_quantize_triton( ) else: + out.remove_caches()  # Make sure to remove the cached transpose if it is marked as invalid out = tex.quantize(input_tensor, quantizer, out, noop_flag) elif isinstance(out, MXFP8TensorBase): te_cast_transpose_mxfp8_triton(input_tensor, out) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 9d0d71fdc..d124fbeaf 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -44,10 +44,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: for t in tensors: if t is not None: # Workaround for double buffering in cpu offload - if hasattr(t, "do_not_clear"): + if hasattr(t, "_do_not_clear"): continue if hasattr(t, "get_data_tensors"): - if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()): + if any(hasattr(tensor, "_do_not_clear") for tensor in t.get_data_tensors()): continue if hasattr(t, "clear"):
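For reference, the selective get_data_tensors accessor added to the FP8 tensor base classes above could be exercised roughly as follows (a hedged sketch, not part of the patch: the Float8Quantizer construction and quantize call follow the usual TE API but are assumptions here, and a CUDA device with FP8 support is assumed):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

# Hypothetical delayed-scaling quantizer; scale/amax bookkeeping is simplified for illustration.
quantizer = Float8Quantizer(
    scale=torch.ones(1, dtype=torch.float32, device="cuda"),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x_fp8 = quantizer.quantize(torch.randn(128, 128, device="cuda"))

# Previous behavior (still the default): both buffers come back as a tuple.
# Either entry may be None depending on the requested usages and the architecture.
rowwise, columnwise = x_fp8.get_data_tensors()

# New selective form: a single tensor (possibly None) is returned instead of a tuple.
rowwise_only = x_fp8.get_data_tensors(rowwise_data=True, columnwise_data=False)
columnwise_only = x_fp8.get_data_tensors(rowwise_data=False, columnwise_data=True)

# Requesting neither buffer raises ValueError, as implemented above.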